summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
commit3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch)
tree429fb99f9c1395272f1f9a94bfbc0e003fa39b21 /src/Compile.hs
parent919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff)
Complete monoidal accumulator rewrite
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs205
1 files changed, 128 insertions, 77 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 503c342..6ba3a39 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1001,16 +1001,6 @@ compile' env = \case
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
EAccum _ t prj eidx eval eacc -> do
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
-
- -- Generate the variable manually because this one has to be non-const.
- -- TODO: old code:
- -- eacc' <- compile' env eacc
- -- nameacc <- genName' "acac"
- -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
- nameacc <- compileAssign "acac" env eacc
-
let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
-- full zero array with the given zero info (for the type SMTArr n t1).
initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
@@ -1041,77 +1031,66 @@ compile' env = \case
initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
initZeroFromMemset SMTScal{} = Nothing
- let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZero :: SMTy a -> String -> String -> CompM ()
- initZero SMTNil _ _ = return ()
- initZero (SMTPair t1 t2) v vzi = do
- initZero t1 (v++".a") (vzi++".a")
- initZero t2 (v++".b") (vzi++".b")
- initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZero (SMTScal sty) v _ = case sty of
+ let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
+ initZeroZI :: SMTy a -> String -> String -> CompM ()
+ initZeroZI SMTNil _ _ = return ()
+ initZeroZI (SMTPair t1 t2) v vzi = do
+ initZeroZI t1 (v++".a") (vzi++".a")
+ initZeroZI t2 (v++".b") (vzi++".b")
+ initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
+ initZeroZI (SMTScal sty) v _ = case sty of
STI32 -> emit $ SAsg v (CELit "0")
STI64 -> emit $ SAsg v (CELit "0l")
STF32 -> emit $ SAsg v (CELit "0.0f")
STF64 -> emit $ SAsg v (CELit "0.0")
- let -- | Dereference an accumulation value. Sparse components encountered
- -- along the way are initialised before proceeding downwards. At the
- -- point where we have the projected accumulator position available,
- -- the handler will be invoked with a variable name pointing to the
- -- projected position.
- -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler)
- accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM ()
- accumRef _ SAPHere v _ k = k v
-
- accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k
- accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k
-
- accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero ta (v++".l") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef ta prj' (v++".l") i k
- accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do
- ((), stmtsInit2) <- scope $ initZero tb (v++".r") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty
- accumRef tb prj' (v++".r") i k
-
- accumRef (SMTMaybe tj) (SAPJust prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero tj (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i k
-
- accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
- let a .||. b = CEBinop a "||" b
- emit $ SIf (CEBinop ixcomp "<" (CELit "0")
- .||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
- "); " ++
- "return false;")
- mempty
-
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k
+ let -- Initialise an uninitialised accumulation value, potentially already
+ -- with the addend, potentially to zero depending on the nature of the
+ -- projection.
+ -- 1. If the projection indexes only through dense monoids before
+ -- reaching SAPHere, the thing cannot be initialised to zero with
+ -- only an AcIdx; it would need to model a zero after the addend,
+ -- which is stupid and redundant. In this case, we return Left:
+ -- (accumulation value) (AcIdx value) (addend value).
+ -- The addend is copied, not consumed. (We can't reliably _always_
+ -- consume it, so it's not worth trying to do it sometimes.)
+ -- 2. Otherwise, a sparse monoid is found along the way, and we can
+ -- initalise the dense prefix of the path to zero by setting the
+ -- indexed-through sparse value to a sparse zero. Afterwards, the
+ -- main recursion can proceed further. In this case, we return
+ -- Right: (accumulation value) (AcIdx value)
+ -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
+ initZeroChunk :: SMTy a -> SAcPrj p a b
+ -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
+ (String -> String -> CompM ()) -- zero initialisation of sparse chunk
+ initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
+ -- reached target before the first sparse constructor
+ (t1 , SAPHere ) -> Left $ \v _ addend -> do
+ incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
+ emit $ SAsg v (CELit addend)
+ -- sparse types
+ (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ -- dense types
+ (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ f (v++".a") (i++".a")
+ initZeroZI t2 (v++".b") (i++".b")
+ (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
+ initZeroZI t1 (v++".a") (i++".a")
+ f (v++".b") (i++".b")
+ (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ initZeroArray n t1 v (i++".a.b")
+ linidxvar <- genName' "li"
+ emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
+ f (v++".buf->xs["++linidxvar++"]") (i++".b")
+ where
+ applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
+ applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
let -- Add a value (s) into an existing accumulation value (d). If a sparse
- -- component of d is encountered, s is simply written there.
+ -- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
add SMTNil _ _ = return ()
add (SMTPair t1 t2) d s = do
@@ -1165,16 +1144,88 @@ compile' env = \case
mempty
shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
ivar <- genName' "i"
((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
stmts1
add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
+
+ accumRef (SMTLEither ta tb) prj0 v i addend = do
+ let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
+ SAPRight prj' -> initZeroChunk tb prj'
+ subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
+ tagval = case prj0 of SAPLeft{} -> "1"
+ SAPRight{} -> "2"
+ ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
+ SAPRight prj' -> accumRef tb prj' subv i addend
+ case chunkres of
+ Left densef -> do
+ ((), stmtsSet) <- scope $ densef subv i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
+ stmtsAdd -- TODO: emit check for consistency of tags?
+ Right sparsef -> do
+ ((), stmtsInit) <- scope $ sparsef subv i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
+ forM_ stmtsAdd emit
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ case initZeroChunk tj prj' of
+ Left densef -> do
+ ((), stmtsSet1) <- scope $ densef (v++".j") i addend
+ ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
+ stmtsAdd1
+ Right sparsef -> do
+ ((), stmtsInit1) <- scope $ sparsef (v++".j") i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip3 [0::Int ..]
+ (indexTupleComponents n (i++".a.a"))
+ (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ .||.
+ CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- accumRef t prj (nameacc++".buf->ac") nameidx $ \dest ->
- add (acPrjTy prj t) dest nameval
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
emit $ SVerbatim $ "// compile EAccum end"
incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval