summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 20:37:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 20:38:05 +0200
commitd0eb9a1edfb4233d557d954f46685f25382234d8 (patch)
tree04eb5a746258fcaa2a3b98228c6eadb2b0178ba3 /src/Compile.hs
parent4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (diff)
Reorder TLEither to after TEither
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs72
1 files changed, 36 insertions, 36 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 6ba3a39..cd10831 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -221,6 +221,7 @@ genStructName = \t -> "ty_" ++ gen t where
gen STNil = "n"
gen (STPair a b) = 'P' : gen a ++ gen b
gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
gen (STMaybe t) = 'M' : gen t
gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
gen (STScal st) = case st of
@@ -230,7 +231,6 @@ genStructName = \t -> "ty_" ++ gen t where
STF64 -> "d"
STBool -> "b"
gen (STAccum t) = 'C' : gen (fromSMTy t)
- gen (STLEither a b) = 'L' : gen a ++ gen b
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
@@ -246,6 +246,8 @@ genStruct name topty = case topty of
[StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
STEither a b -> -- 0 -> l, 1 -> r
[StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
STMaybe t -> -- 0 -> nothing, 1 -> just
[StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
STArr n t ->
@@ -259,8 +261,6 @@ genStruct name topty = case topty of
STAccum t ->
[StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
,StructDecl name (name ++ "_buf *buf;") com]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
where
com = ppSTy 0 topty
@@ -282,11 +282,11 @@ genStructs ty = do
STNil -> pure ()
STPair a b -> genStructs a >> genStructs b
STEither a b -> genStructs a >> genStructs b
+ STLEither a b -> genStructs a >> genStructs b
STMaybe t -> genStructs t
STArr _ t -> genStructs t
STScal _ -> pure ()
STAccum t -> genStructs (fromSMTy t)
- STLEither a b -> genStructs a >> genStructs b
tell (BList (genStruct name ty))
@@ -463,6 +463,15 @@ serialise topty topval ptr off k =
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -493,15 +502,6 @@ serialise topty topval ptr off k =
STF64 -> pokeByteOff ptr off (x :: Double) >> k
STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
(STAccum{}, _) -> error "Cannot serialise accumulators"
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
-- | Assumes that this is called at the correct alignment.
deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
@@ -518,6 +518,13 @@ deserialise topty ptr off =
if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -541,13 +548,6 @@ deserialise topty ptr off =
STF64 -> peekByteOff @Double ptr off
STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
STAccum{} -> error "Cannot serialise accumulators"
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
align :: Int -> Int -> Int
align a off = (off + a - 1) `div` a * a
@@ -569,6 +569,10 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
@@ -580,10 +584,6 @@ metricsSTy (STScal sty) = case sty of
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -1071,8 +1071,8 @@ compile' env = \case
incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
emit $ SAsg v (CELit addend)
-- sparse types
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
(SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
-- dense types
(SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
f (v++".a") (i++".a")
@@ -1303,13 +1303,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
@@ -1657,6 +1657,14 @@ zeroRefcountCheck toptyp opname topvar =
go (STEither a b) path = do
(s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
go (STMaybe a) path = do
ss <- go a (path++".j")
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
@@ -1673,14 +1681,6 @@ zeroRefcountCheck toptyp opname topvar =
return (BList [s1, s2, s3])
go STScal{} _ = empty
go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
combine (MaybeT a) (MaybeT b) = MaybeT $ do