summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs441
1 files changed, 275 insertions, 166 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index e2d004a..503c342 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -45,7 +45,6 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
-import qualified CHAD.Types as CHAD
import Compile.Exec
import Data
import Interpreter.Rep
@@ -230,11 +229,15 @@ genStructName = \t -> "ty_" ++ gen t where
STF32 -> "f"
STF64 -> "d"
STBool -> "b"
- gen (STAccum t) = 'C' : gen t
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
+ gen (STLEither a b) = 'L' : gen a ++ gen b
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
-- types in the C translation.
+--
+-- For accumulation it is important that for struct representations of monoid
+-- types, the all-zero-bytes value corresponds to the zero value of that type.
genStruct :: String -> STy t -> [StructDecl]
genStruct name topty = case topty of
STNil ->
@@ -247,13 +250,17 @@ genStruct name topty = case topty of
[StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
+ -- TODO: put shape in the main struct, not the buffer; it's constant, after all
+ -- TODO: no buffer if n = 0
[StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
STScal _ ->
[]
STAccum t ->
- [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") ""
+ [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
,StructDecl name (name ++ "_buf *buf;") com]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
where
com = ppSTy 0 topty
@@ -278,7 +285,8 @@ genStructs ty = do
STMaybe t -> genStructs t
STArr _ t -> genStructs t
STScal _ -> pure ()
- STAccum t -> genStructs (CHAD.d2 t)
+ STAccum t -> genStructs (fromSMTy t)
+ STLEither a b -> genStructs a >> genStructs b
tell (BList (genStruct name ty))
@@ -450,7 +458,7 @@ serialise topty topval ptr off k =
serialise a x ptr off $
serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
(STEither a _, Left x) -> do
- pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b})
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
serialise a x ptr (off + alignmentSTy topty) k
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
@@ -485,6 +493,15 @@ serialise topty topval ptr off k =
STF64 -> pokeByteOff ptr off (x :: Double) >> k
STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
(STAccum{}, _) -> error "Cannot serialise accumulators"
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
-- | Assumes that this is called at the correct alignment.
deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
@@ -498,7 +515,7 @@ deserialise topty ptr off =
return (x, y)
STEither a b -> do
tag <- peekByteOff @Word8 ptr off
- if tag == 0 -- alignment of (a + b) is alignment of (union {a b})
+ if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
STMaybe t -> do
@@ -524,6 +541,13 @@ deserialise topty ptr off =
STF64 -> peekByteOff @Double ptr off
STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
STAccum{} -> error "Cannot serialise accumulators"
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
align :: Int -> Int -> Int
align a off = (off + a - 1) `div` a * a
@@ -555,7 +579,11 @@ metricsSTy (STScal sty) = case sty of
STF32 -> (4, 4)
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
-metricsSTy (STAccum t) = metricsSTy t
+metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -685,6 +713,39 @@ compile' env = \case
<> pure (SAsg retvar e3))))
return (CELit retvar)
+ ELNil _ t1 t2 -> do
+ name <- emitStruct (STLEither t1 t2)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ ELInl _ t e -> do
+ name <- emitStruct (STLEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
+
+ ELInr _ t e -> do
+ name <- emitStruct (STLEither t (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
+
+ ELCase _ e a b c -> do
+ let STLEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
+ (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
+ ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2 <> pure (SAsg retvar e2))
+ (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
+ (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
+ (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
+ return (CELit retvar)
+
EConstArr _ n t (Array sh vec) -> do
strname <- emitStruct (STArr n (STScal t))
tldname <- genName' "carraybuf"
@@ -734,8 +795,7 @@ compile' env = \case
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
- resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname))
- [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
@@ -781,8 +841,7 @@ compile' env = \case
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname))
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
@@ -833,8 +892,7 @@ compile' env = \case
resname <- allocArray "repl1i" Malloc "rep" (SS n) t
(Just (CEBinop (CELit shszname) "*" (CELit lenname)))
- ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
- ++ [CELit lenname])
+ (compileArrShapeComponents n argname ++ [CELit lenname])
ivar <- genName' "i"
jvar <- genName' "j"
@@ -926,20 +984,20 @@ compile' env = \case
zeroRefcountCheck (typeOf e1) "with" name1
emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
- mcopy <- copyForWriting (CHAD.d2 t) name1
+ mcopy <- copyForWriting t name1
accname <- genName' "accum"
emit $ SVarDecl False actyname accname
- (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (CHAD.d2 t)))])])
+ (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
e2' <- compile' (Const accname `SCons` env) e2
resname <- genName' "acret"
- emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
- rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
+ rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
EAccum _ t prj eidx eval eacc -> do
@@ -947,156 +1005,180 @@ compile' env = \case
nameval <- compileAssign "acval" env eval
-- Generate the variable manually because this one has to be non-const.
- eacc' <- compile' env eacc
- nameacc <- genName' "acac"
- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
-
- let -- Expects a variable reference to a value of type @D2 a@.
- setZero :: STy a -> String -> CompM ()
- setZero STNil _ = return ()
- setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b))
- setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b))
- setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a)
- setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a))
- setZero (STScal sty) v = case sty of
- STI32 -> return () -- Nil
- STI64 -> return () -- Nil
+ -- TODO: old code:
+ -- eacc' <- compile' env eacc
+ -- nameacc <- genName' "acac"
+ -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
+ nameacc <- compileAssign "acac" env eacc
+
+ let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
+ -- full zero array with the given zero info (for the type SMTArr n t1).
+ initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
+ initZeroArray n t1 v vzi = do
+ shszname <- genName' "inacshsz"
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
+ newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
+ emit $ SAsg v (CELit newarrName)
+ forM_ (initZeroFromMemset t1) $ \f1 -> do
+ ivar <- genName' "i"
+ ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
+
+ -- If something needs to be done to properly initialise this type to
+ -- zero after memory has already been initialised to all-zero bytes,
+ -- returns an action that does so.
+ -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
+ initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
+ initZeroFromMemset SMTNil = Nothing
+ initZeroFromMemset (SMTPair t1 t2) =
+ case (initZeroFromMemset t1, initZeroFromMemset t2) of
+ (Nothing, Nothing) -> Nothing
+ (mf1, mf2) -> Just $ \v vzi -> do
+ forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
+ forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
+ initZeroFromMemset SMTLEither{} = Nothing
+ initZeroFromMemset SMTMaybe{} = Nothing
+ initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
+ initZeroFromMemset SMTScal{} = Nothing
+
+ let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
+ initZero :: SMTy a -> String -> String -> CompM ()
+ initZero SMTNil _ _ = return ()
+ initZero (SMTPair t1 t2) v vzi = do
+ initZero t1 (v++".a") (vzi++".a")
+ initZero t2 (v++".b") (vzi++".b")
+ initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
+ initZero (SMTScal sty) v _ = case sty of
+ STI32 -> emit $ SAsg v (CELit "0")
+ STI64 -> emit $ SAsg v (CELit "0l")
STF32 -> emit $ SAsg v (CELit "0.0f")
STF64 -> emit $ SAsg v (CELit "0.0")
- STBool -> return () -- Nil
- setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported"
- initD2Pair :: STy a -> STy b -> String -> CompM ()
- initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b))
- ((), stmts1) <- scope $ setZero a (v++".j.a")
- ((), stmts2) <- scope $ setZero b (v++".j.b")
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2)
- mempty
+ let -- | Dereference an accumulation value. Sparse components encountered
+ -- along the way are initialised before proceeding downwards. At the
+ -- point where we have the projected accumulator position available,
+ -- the handler will be invoked with a variable name pointing to the
+ -- projected position.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM ()
+ accumRef _ SAPHere v _ k = k v
- initD2Either :: STy a -> STy b -> String -> Either () () -> CompM ()
- initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b))
- ((), stmts) <- case side of
- Left () -> scope $ setZero a (v++".j.l")
- Right () -> scope $ setZero b (v++".j.r")
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
- mempty
+ accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k
- initD2Maybe :: STy a -> String -> CompM ()
- initD2Maybe a v = do -- Maybe (D2 a)
- ((), stmts) <- scope $ setZero a (v++".j")
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do
+ ((), stmtsInit1) <- scope $ initZero ta (v++".l") i
emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
- mempty
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
+ accumRef ta prj' (v++".l") i k
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do
+ ((), stmtsInit2) <- scope $ initZero tb (v++".r") i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty
+ accumRef tb prj' (v++".r") i k
- -- mind: this has to traverse the D2 of these things, and it also has to
- -- initialise data structures that are still sparse in the accumulator.
- let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String
- accumRef _ SAPHere v _ = pure v
- accumRef (STPair ta tb) (SAPFst prj') v i = do
- initD2Pair ta tb v
- accumRef ta prj' (v++".j.a") i
- accumRef (STPair ta tb) (SAPSnd prj') v i = do
- initD2Pair ta tb v
- accumRef tb prj' (v++".j.b") i
- accumRef (STEither ta tb) (SAPLeft prj') v i = do
- initD2Either ta tb v (Left ())
- accumRef ta prj' (v++".j.l") i
- accumRef (STEither ta tb) (SAPRight prj') v i = do
- initD2Either ta tb v (Right ())
- accumRef tb prj' (v++".j.r") i
- accumRef (STMaybe tj) (SAPJust prj') v i = do
- initD2Maybe tj v
- accumRef tj prj' (v++".j") i
- accumRef (STArr n t') (SAPArrIdx prj' _) v i = do
- (newarrName, newarrStmts) <- scope $ allocArray "accumRef" Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b"))
+ accumRef (SMTMaybe tj) (SAPJust prj') v i k = do
+ ((), stmtsInit1) <- scope $ initZero tj (v++".j") i
emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1"))
- <> newarrStmts
- <> pure (SAsg (v++".j") (CELit newarrName)))
- mempty
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
+ accumRef tj prj' (v++".j") i k
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
forM_ (zip3 [0::Int ..]
(indexTupleComponents n (i++".a.a"))
- (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
.||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))))
+ CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
"arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".j.buf" ++
- concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
"); " ++
"return false;")
mempty
- accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b")
-
- -- mind: this has to add the D2 of these things, and it also has to
- -- initialise data structures that are still sparse in the accumulator.
- let add :: STy a -> String -> String -> CompM ()
- add STNil _ _ = return ()
- add (STPair t1 t2) d s = do
- ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a")
- ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (stmts1 <> stmts2)
- mempty))
- add (STEither t1 t2) d s = do
- ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l")
- ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r")
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k
+
+ let -- Add a value (s) into an existing accumulation value (d). If a sparse
+ -- component of d is encountered, s is simply written there.
+ add :: SMTy a -> String -> String -> CompM ()
+ add SMTNil _ _ = return ()
+ add (SMTPair t1 t2) d s = do
+ add t1 (d++".a") (s++".a")
+ add t2 (d++".b") (s++".b")
+ add (SMTLEither t1 t2) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
+ ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
+ ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SAsg (d++".j.tag") (CELit (s++".j.tag")))
- <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0"))
- stmts1 stmts2))
- mempty))
- add (STMaybe t1) d s = do
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ ((if emitChecks
+ then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
+ "&&"
+ (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
+ "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
+ "return false;")
+ mempty)
+ else mempty)
+ -- note: s may have tag 0
+ <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ stmts1
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
+ stmts2 mempty))))
+ add (SMTMaybe t1) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SAsg (d++".tag") (CELit "1")) <> stmts1)
- mempty))
- add (STArr n t1) d s = do
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
+ add (SMTArr n t1) d s = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ [0 .. fromSNat n - 1] $ \j -> do
+ emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ "!="
+ (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
+ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
+ d ++ ".buf" ++
+ concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ ", " ++ s ++ ".buf" ++
+ concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ "); " ++
+ "return false;")
+ mempty
+
shsizename <- genName' "acshsz"
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))
ivar <- genName' "i"
- ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]")
- ((), stmtsDecr) <- scope $ incrementVarAlways "accumarr" Decrement (STArr n (CHAD.d2 t1)) (s++".j")
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")))
- -- TODO: emit check here for the source being either equal in shape to the destination
- <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
- stmts1)
- <> stmtsDecr)))
- mempty
- add (STScal sty) d s = case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STBool -> return ()
- add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
+ ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1
+ add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- dest <- accumRef t prj (nameacc++".buf->ac") nameidx
- add (acPrjTy prj t) dest nameval
+ accumRef t prj (nameacc++".buf->ac") nameidx $ \dest ->
+ add (acPrjTy prj t) dest nameval
emit $ SVerbatim $ "// compile EAccum end"
+ incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
+
return $ CEStruct (repSTy STNil) []
EError _ t s -> do
@@ -1111,9 +1193,9 @@ compile' env = \case
name <- emitStruct t
return $ CEStruct name []
- EZero{} -> error "Compile: monoid operations should have been eliminated"
- EPlus{} -> error "Compile: monoid operations should have been eliminated"
- EOneHot{} -> error "Compile: monoid operations should have been eliminated"
+ EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EIdx1{} -> error "Compile: not implemented: EIdx1"
@@ -1144,6 +1226,7 @@ data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array
| ATNoop -- ^ don't do anything here
| ATProj String ArrayTree -- ^ descend one field deeper
| ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
+ | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
| ATBoth ArrayTree ArrayTree -- ^ do both these paths
smartATProj :: String -> ArrayTree -> ArrayTree
@@ -1154,6 +1237,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
smartATCondTag ATNoop ATNoop = ATNoop
smartATCondTag t t' = ATCondTag t t'
+smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
+smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
+smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
+
smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
smartATBoth ATNoop t = t
smartATBoth t ATNoop = t
@@ -1169,6 +1256,9 @@ makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTre
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
@@ -1204,6 +1294,15 @@ incrementVar' marker inc path (ATCondTag t1 t2) = do
((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
+incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
+ stmts2
+ (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
+ stmts3
+ stmts1))
incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
toLinearIdx :: SNat n -> String -> String -> CExpr
@@ -1257,10 +1356,12 @@ compileShapeQuery (SS n) var =
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize SZ _ = CELit "1"
-compileArrShapeSize n var =
- foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]")
- | i <- [0 .. fromSNat n - 1]]
+compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeComponents :: SNat n -> String -> [CExpr]
+compileArrShapeComponents n var =
+ [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
indexTupleComponents :: SNat n -> String -> [CExpr]
indexTupleComponents = \n var -> map CELit (toList (go n var))
@@ -1347,8 +1448,7 @@ compileExtremum nameBase opName operator env e = do
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname))
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
@@ -1375,47 +1475,47 @@ compileExtremum nameBase opName operator env e = do
-- | If this returns Nothing, there was nothing to copy because making a simple
-- value copy in C already makes it suitable to write to.
-copyForWriting :: STy t -> String -> CompM (Maybe CExpr)
+copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
copyForWriting topty var = case topty of
- STNil -> return Nothing
+ SMTNil -> return Nothing
- STPair a b -> do
+ SMTPair a b -> do
e1 <- copyForWriting a (var ++ ".a")
e2 <- copyForWriting b (var ++ ".b")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
- _ -> return $ Just $ CEStruct (repSTy topty)
+ _ -> return $ Just $ CEStruct toptyname
[("a", fromMaybe (CELit (var++".a")) e1)
,("b", fromMaybe (CELit (var++".b")) e2)]
- STEither a b -> do
+ SMTLEither a b -> do
(e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
(e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
_ -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SVarDeclUninit toptyname name
emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(stmts1
- <> pure (SAsg name (CEStruct (repSTy topty)
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
(stmts2
- <> pure (SAsg name (CEStruct (repSTy topty)
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
return (Just (CELit name))
- STMaybe t -> do
+ SMTMaybe t -> do
(e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
case e1 of
Nothing -> return Nothing
Just e1' -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SVarDeclUninit toptyname name
emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
+ (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
(stmts1
- <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
+ <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
return (Just (CELit name))
-- If there are no nested arrays, we know that a refcount of 1 means that the
@@ -1423,10 +1523,10 @@ copyForWriting topty var = case topty of
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
-- let's not do that. Furthermore, no sub-arrays means that the whole thing
-- is flat, and we can just memcpy if necessary.
- STArr n t | not (hasArrays t) -> do
+ SMTArr n t | not (hasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
- emit $ SVarDeclUninit (repSTy (STArr n t)) name
+ emit $ SVarDeclUninit toptyname name
when debugShapes $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
@@ -1438,11 +1538,11 @@ copyForWriting topty var = case topty of
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
(pure (SAsg name (CELit var)))
(let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
- ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])])
+ ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
@@ -1450,26 +1550,26 @@ copyForWriting topty var = case topty of
printCExpr 0 databytes ");"])
return (Just (CELit name))
- STArr n t -> do
+ SMTArr n t -> do
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
name <- genName
- emit $ SVarDecl False (repSTy (STArr n t)) name
- (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])])
+ emit $ SVarDecl False toptyname name
+ (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain
dstvar <- genName' "cpydst"
- emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
srcvar <- genName' "cpysrc"
- emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
ivar <- genName' "i"
@@ -1484,9 +1584,10 @@ copyForWriting topty var = case topty of
return (Just (CELit name))
- STScal _ -> return Nothing
+ SMTScal _ -> return Nothing
- STAccum _ -> error "Compile: Nested accumulators not supported"
+ where
+ toptyname = repSTy (fromSMTy topty)
zeroRefcountCheck :: STy t -> String -> String -> CompM ()
zeroRefcountCheck toptyp opname topvar =
@@ -1521,6 +1622,14 @@ zeroRefcountCheck toptyp opname topvar =
return (BList [s1, s2, s3])
go STScal{} _ = empty
go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
combine (MaybeT a) (MaybeT b) = MaybeT $ do