summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs171
1 files changed, 38 insertions, 133 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 722b432..a5c4fb7 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -45,6 +45,7 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
@@ -1002,95 +1003,7 @@ compile' env = \case
rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
- EAccum _ t prj eidx eval eacc -> do
- let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
- -- full zero array with the given zero info (for the type SMTArr n t1).
- initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
- initZeroArray n t1 v vzi = do
- shszname <- genName' "inacshsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
- newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
- emit $ SAsg v (CELit newarrName)
- forM_ (initZeroFromMemset t1) $ \f1 -> do
- ivar <- genName' "i"
- ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
-
- -- If something needs to be done to properly initialise this type to
- -- zero after memory has already been initialised to all-zero bytes,
- -- returns an action that does so.
- -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
- initZeroFromMemset SMTNil = Nothing
- initZeroFromMemset (SMTPair t1 t2) =
- case (initZeroFromMemset t1, initZeroFromMemset t2) of
- (Nothing, Nothing) -> Nothing
- (mf1, mf2) -> Just $ \v vzi -> do
- forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
- forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
- initZeroFromMemset SMTLEither{} = Nothing
- initZeroFromMemset SMTMaybe{} = Nothing
- initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
- initZeroFromMemset SMTScal{} = Nothing
-
- let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroZI :: SMTy a -> String -> String -> CompM ()
- initZeroZI SMTNil _ _ = return ()
- initZeroZI (SMTPair t1 t2) v vzi = do
- initZeroZI t1 (v++".a") (vzi++".a")
- initZeroZI t2 (v++".b") (vzi++".b")
- initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZeroZI (SMTScal sty) v _ = case sty of
- STI32 -> emit $ SAsg v (CELit "0")
- STI64 -> emit $ SAsg v (CELit "0l")
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
-
- let -- Initialise an uninitialised accumulation value, potentially already
- -- with the addend, potentially to zero depending on the nature of the
- -- projection.
- -- 1. If the projection indexes only through dense monoids before
- -- reaching SAPHere, the thing cannot be initialised to zero with
- -- only an AcIdx; it would need to model a zero after the addend,
- -- which is stupid and redundant. In this case, we return Left:
- -- (accumulation value) (AcIdx value) (addend value).
- -- The addend is copied, not consumed. (We can't reliably _always_
- -- consume it, so it's not worth trying to do it sometimes.)
- -- 2. Otherwise, a sparse monoid is found along the way, and we can
- -- initalise the dense prefix of the path to zero by setting the
- -- indexed-through sparse value to a sparse zero. Afterwards, the
- -- main recursion can proceed further. In this case, we return
- -- Right: (accumulation value) (AcIdx value)
- -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
- initZeroChunk :: SMTy a -> SAcPrj p a b
- -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
- (String -> String -> CompM ()) -- zero initialisation of sparse chunk
- initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
- -- reached target before the first sparse constructor
- (t1 , SAPHere ) -> Left $ \v _ addend -> do
- incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
- emit $ SAsg v (CELit addend)
- -- sparse types
- (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- -- dense types
- (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- f (v++".a") (i++".a")
- initZeroZI t2 (v++".b") (i++".b")
- (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
- initZeroZI t1 (v++".a") (i++".a")
- f (v++".b") (i++".b")
- (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- initZeroArray n t1 v (i++".a.b")
- linidxvar <- genName' "li"
- emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
- f (v++".buf->xs["++linidxvar++"]") (i++".b")
- where
- applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
- applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
-
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
let -- Add a value (s) into an existing accumulation value (d). If a sparse
-- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
@@ -1160,67 +1073,55 @@ compile' env = \case
accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
-
- accumRef (SMTLEither ta tb) prj0 v i addend = do
- let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
- SAPRight prj' -> initZeroChunk tb prj'
- subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
- tagval = case prj0 of SAPLeft{} -> "1"
- SAPRight{} -> "2"
- ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
- SAPRight prj' -> accumRef tb prj' subv i addend
- case chunkres of
- Left densef -> do
- ((), stmtsSet) <- scope $ densef subv i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
- stmtsAdd -- TODO: emit check for consistency of tags?
- Right sparsef -> do
- ((), stmtsInit) <- scope $ sparsef subv i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
- forM_ stmtsAdd emit
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- case initZeroChunk tj prj' of
- Left densef -> do
- ((), stmtsSet1) <- scope $ densef (v++".j") i addend
- ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
- stmtsAdd1
- Right sparsef -> do
- ((), stmtsInit1) <- scope $ sparsef (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i addend
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
v ++ ".buf" ++
concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
"return false;")
mempty
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
nameidx <- compileAssign "acidx" env eidx
nameval <- compileAssign "acval" env eval
@@ -1234,6 +1135,9 @@ compile' env = \case
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
@@ -1247,6 +1151,7 @@ compile' env = \case
return $ CEStruct name []
EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"