diff options
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 205 |
1 files changed, 128 insertions, 77 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 503c342..6ba3a39 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1001,16 +1001,6 @@ compile' env = \case return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] EAccum _ t prj eidx eval eacc -> do - nameidx <- compileAssign "acidx" env eidx - nameval <- compileAssign "acval" env eval - - -- Generate the variable manually because this one has to be non-const. - -- TODO: old code: - -- eacc' <- compile' env eacc - -- nameacc <- genName' "acac" - -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' - nameacc <- compileAssign "acac" env eacc - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a -- full zero array with the given zero info (for the type SMTArr n t1). initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () @@ -1041,77 +1031,66 @@ compile' env = \case initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi initZeroFromMemset SMTScal{} = Nothing - let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZero :: SMTy a -> String -> String -> CompM () - initZero SMTNil _ _ = return () - initZero (SMTPair t1 t2) v vzi = do - initZero t1 (v++".a") (vzi++".a") - initZero t2 (v++".b") (vzi++".b") - initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZero (SMTScal sty) v _ = case sty of + let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZeroZI :: SMTy a -> String -> String -> CompM () + initZeroZI SMTNil _ _ = return () + initZeroZI (SMTPair t1 t2) v vzi = do + initZeroZI t1 (v++".a") (vzi++".a") + initZeroZI t2 (v++".b") (vzi++".b") + initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi + initZeroZI (SMTScal sty) v _ = case sty of STI32 -> emit $ SAsg v (CELit "0") STI64 -> emit $ SAsg v (CELit "0l") STF32 -> emit $ SAsg v (CELit "0.0f") STF64 -> emit $ SAsg v (CELit "0.0") - let -- | Dereference an accumulation value. Sparse components encountered - -- along the way are initialised before proceeding downwards. At the - -- point where we have the projected accumulator position available, - -- the handler will be invoked with a variable name pointing to the - -- projected position. - -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) - accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () - accumRef _ SAPHere v _ k = k v - - accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k - accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - - accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do - ((), stmtsInit1) <- scope $ initZero ta (v++".l") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef ta prj' (v++".l") i k - accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do - ((), stmtsInit2) <- scope $ initZero tb (v++".r") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty - accumRef tb prj' (v++".r") i k - - accumRef (SMTMaybe tj) (SAPJust prj') v i k = do - ((), stmtsInit1) <- scope $ initZero tj (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i k - - accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do - let a .||. b = CEBinop a "||" b - emit $ SIf (CEBinop ixcomp "<" (CELit "0") - .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ - "); " ++ - "return false;") - mempty - - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k + let -- Initialise an uninitialised accumulation value, potentially already + -- with the addend, potentially to zero depending on the nature of the + -- projection. + -- 1. If the projection indexes only through dense monoids before + -- reaching SAPHere, the thing cannot be initialised to zero with + -- only an AcIdx; it would need to model a zero after the addend, + -- which is stupid and redundant. In this case, we return Left: + -- (accumulation value) (AcIdx value) (addend value). + -- The addend is copied, not consumed. (We can't reliably _always_ + -- consume it, so it's not worth trying to do it sometimes.) + -- 2. Otherwise, a sparse monoid is found along the way, and we can + -- initalise the dense prefix of the path to zero by setting the + -- indexed-through sparse value to a sparse zero. Afterwards, the + -- main recursion can proceed further. In this case, we return + -- Right: (accumulation value) (AcIdx value) + -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) + initZeroChunk :: SMTy a -> SAcPrj p a b + -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend + (String -> String -> CompM ()) -- zero initialisation of sparse chunk + initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of + -- reached target before the first sparse constructor + (t1 , SAPHere ) -> Left $ \v _ addend -> do + incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend + emit $ SAsg v (CELit addend) + -- sparse types + (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + -- dense types + (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do + f (v++".a") (i++".a") + initZeroZI t2 (v++".b") (i++".b") + (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do + initZeroZI t1 (v++".a") (i++".a") + f (v++".b") (i++".b") + (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do + initZeroArray n t1 v (i++".a.b") + linidxvar <- genName' "li" + emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) + f (v++".buf->xs["++linidxvar++"]") (i++".b") + where + applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i + applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i let -- Add a value (s) into an existing accumulation value (d). If a sparse - -- component of d is encountered, s is simply written there. + -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () add SMTNil _ _ = return () add (SMTPair t1 t2) d s = do @@ -1165,16 +1144,88 @@ compile' env = \case mempty shsizename <- genName' "acshsz" - emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) ivar <- genName' "i" ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) stmts1 add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + + accumRef (SMTLEither ta tb) prj0 v i addend = do + let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' + SAPRight prj' -> initZeroChunk tb prj' + subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") + tagval = case prj0 of SAPLeft{} -> "1" + SAPRight{} -> "2" + ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend + SAPRight prj' -> accumRef tb prj' subv i addend + case chunkres of + Left densef -> do + ((), stmtsSet) <- scope $ densef subv i addend + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) + stmtsAdd -- TODO: emit check for consistency of tags? + Right sparsef -> do + ((), stmtsInit) <- scope $ sparsef subv i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty + forM_ stmtsAdd emit + + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + case initZeroChunk tj prj' of + Left densef -> do + ((), stmtsSet1) <- scope $ densef (v++".j") i addend + ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) + stmtsAdd1 + Right sparsef -> do + ((), stmtsInit1) <- scope $ sparsef (v++".j") i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef tj prj' (v++".j") i addend + + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip3 [0::Int ..] + (indexTupleComponents n (i++".a.a")) + (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + .||. + CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + "); " ++ + "return false;") + mempty + + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> - add (acPrjTy prj t) dest nameval + accumRef t prj (nameacc++".buf->ac") nameidx nameval emit $ SVerbatim $ "// compile EAccum end" incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval |