diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Compile.hs | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! This paves the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent", each of which can be lowered to
either a sparse monoid or a non-sparse monoid. Downside of this
approach: a lot of API duplication.
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 441 |
1 file changed, 275 insertions, 166 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index e2d004a..503c342 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -45,7 +45,6 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) -import qualified CHAD.Types as CHAD import Compile.Exec import Data import Interpreter.Rep @@ -230,11 +229,15 @@ genStructName = \t -> "ty_" ++ gen t where STF32 -> "f" STF64 -> "d" STBool -> "b" - gen (STAccum t) = 'C' : gen t + gen (STAccum t) = 'C' : gen (fromSMTy t) + gen (STLEither a b) = 'L' : gen a ++ gen b -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. genStruct :: String -> STy t -> [StructDecl] genStruct name topty = case topty of STNil -> @@ -247,13 +250,17 @@ genStruct name topty = case topty of [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. 
+ -- TODO: put shape in the main struct, not the buffer; it's constant, after all + -- TODO: no buffer if n = 0 [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" ,StructDecl name (name ++ "_buf *buf;") com] STScal _ -> [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") "" + [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" ,StructDecl name (name ++ "_buf *buf;") com] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] where com = ppSTy 0 topty @@ -278,7 +285,8 @@ genStructs ty = do STMaybe t -> genStructs t STArr _ t -> genStructs t STScal _ -> pure () - STAccum t -> genStructs (CHAD.d2 t) + STAccum t -> genStructs (fromSMTy t) + STLEither a b -> genStructs a >> genStructs b tell (BList (genStruct name ty)) @@ -450,7 +458,7 @@ serialise topty topval ptr off k = serialise a x ptr off $ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b}) + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) serialise a x ptr (off + alignmentSTy topty) k (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) @@ -485,6 +493,15 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy 
topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -498,7 +515,7 @@ deserialise topty ptr off = return (x, y) STEither a b -> do tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (a + b) is alignment of (union {a b}) + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) STMaybe t -> do @@ -524,6 +541,13 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -555,7 +579,11 @@ metricsSTy (STScal sty) = case sty of STF32 -> (4, 4) STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . 
fromSNat @@ -685,6 +713,39 @@ compile' env = \case <> pure (SAsg retvar e3)))) return (CELit retvar) + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + EConstArr _ n t (Array sh vec) -> do strname <- emitStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" @@ -734,8 +795,7 @@ compile' env = \case -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) - [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. 
fromSNat n - 1]] + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -781,8 +841,7 @@ compile' env = \case -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -833,8 +892,7 @@ compile' env = \case resname <- allocArray "repl1i" Malloc "rep" (SS n) t (Just (CEBinop (CELit shszname) "*" (CELit lenname))) - ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - ++ [CELit lenname]) + (compileArrShapeComponents n argname ++ [CELit lenname]) ivar <- genName' "i" jvar <- genName' "j" @@ -926,20 +984,20 @@ compile' env = \case zeroRefcountCheck (typeOf e1) "with" name1 emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" - mcopy <- copyForWriting (CHAD.d2 t) name1 + mcopy <- copyForWriting t name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname - (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (CHAD.d2 t)))])]) + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." 
e2' <- compile' (Const accname `SCons` env) e2 resname <- genName' "acret" - emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac")) + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" - rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] EAccum _ t prj eidx eval eacc -> do @@ -947,156 +1005,180 @@ compile' env = \case nameval <- compileAssign "acval" env eval -- Generate the variable manually because this one has to be non-const. - eacc' <- compile' env eacc - nameacc <- genName' "acac" - emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' - - let -- Expects a variable reference to a value of type @D2 a@. - setZero :: STy a -> String -> CompM () - setZero STNil _ = return () - setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b)) - setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b)) - setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a) - setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a)) - setZero (STScal sty) v = case sty of - STI32 -> return () -- Nil - STI64 -> return () -- Nil + -- TODO: old code: + -- eacc' <- compile' env eacc + -- nameacc <- genName' "acac" + -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' + nameacc <- compileAssign "acac" env eacc + + let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a + -- full zero array with the given zero info (for the type SMTArr n t1). 
+ initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () + initZeroArray n t1 v vzi = do + shszname <- genName' "inacshsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) + newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) + emit $ SAsg v (CELit newarrName) + forM_ (initZeroFromMemset t1) $ \f1 -> do + ivar <- genName' "i" + ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts + + -- If something needs to be done to properly initialise this type to + -- zero after memory has already been initialised to all-zero bytes, + -- returns an action that does so. + -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) + initZeroFromMemset SMTNil = Nothing + initZeroFromMemset (SMTPair t1 t2) = + case (initZeroFromMemset t1, initZeroFromMemset t2) of + (Nothing, Nothing) -> Nothing + (mf1, mf2) -> Just $ \v vzi -> do + forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") + forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") + initZeroFromMemset SMTLEither{} = Nothing + initZeroFromMemset SMTMaybe{} = Nothing + initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi + initZeroFromMemset SMTScal{} = Nothing + + let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZero :: SMTy a -> String -> String -> CompM () + initZero SMTNil _ _ = return () + initZero (SMTPair t1 t2) v vzi = do + initZero t1 (v++".a") (vzi++".a") + initZero t2 (v++".b") (vzi++".b") + initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi + initZero (SMTScal sty) v _ = case sty of + STI32 
-> emit $ SAsg v (CELit "0") + STI64 -> emit $ SAsg v (CELit "0l") STF32 -> emit $ SAsg v (CELit "0.0f") STF64 -> emit $ SAsg v (CELit "0.0") - STBool -> return () -- Nil - setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported" - initD2Pair :: STy a -> STy b -> String -> CompM () - initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b)) - ((), stmts1) <- scope $ setZero a (v++".j.a") - ((), stmts2) <- scope $ setZero b (v++".j.b") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2) - mempty + let -- | Dereference an accumulation value. Sparse components encountered + -- along the way are initialised before proceeding downwards. At the + -- point where we have the projected accumulator position available, + -- the handler will be invoked with a variable name pointing to the + -- projected position. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () + accumRef _ SAPHere v _ k = k v - initD2Either :: STy a -> STy b -> String -> Either () () -> CompM () - initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b)) - ((), stmts) <- case side of - Left () -> scope $ setZero a (v++".j.l") - Right () -> scope $ setZero b (v++".j.r") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - mempty + accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k + accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - initD2Maybe :: STy a -> String -> CompM () - initD2Maybe a v = do -- Maybe (D2 a) - ((), stmts) <- scope $ setZero a (v++".j") + accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do + ((), stmtsInit1) <- scope $ initZero ta (v++".l") i emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - 
mempty + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef ta prj' (v++".l") i k + accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do + ((), stmtsInit2) <- scope $ initZero tb (v++".r") i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty + accumRef tb prj' (v++".r") i k - -- mind: this has to traverse the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. - let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String - accumRef _ SAPHere v _ = pure v - accumRef (STPair ta tb) (SAPFst prj') v i = do - initD2Pair ta tb v - accumRef ta prj' (v++".j.a") i - accumRef (STPair ta tb) (SAPSnd prj') v i = do - initD2Pair ta tb v - accumRef tb prj' (v++".j.b") i - accumRef (STEither ta tb) (SAPLeft prj') v i = do - initD2Either ta tb v (Left ()) - accumRef ta prj' (v++".j.l") i - accumRef (STEither ta tb) (SAPRight prj') v i = do - initD2Either ta tb v (Right ()) - accumRef tb prj' (v++".j.r") i - accumRef (STMaybe tj) (SAPJust prj') v i = do - initD2Maybe tj v - accumRef tj prj' (v++".j") i - accumRef (STArr n t') (SAPArrIdx prj' _) v i = do - (newarrName, newarrStmts) <- scope $ allocArray "accumRef" Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b")) + accumRef (SMTMaybe tj) (SAPJust prj') v i k = do + ((), stmtsInit1) <- scope $ initZero tj (v++".j") i emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) - <> newarrStmts - <> pure (SAsg (v++".j") (CELit newarrName))) - mempty + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef tj prj' (v++".j") i k + accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ (zip3 [0::Int ..] 
(indexTupleComponents n (i++".a.a")) - (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))) + CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".j.buf" ++ - concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") - - -- mind: this has to add the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. 
- let add :: STy a -> String -> String -> CompM () - add STNil _ _ = return () - add (STPair t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a") - ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (stmts1 <> stmts2) - mempty)) - add (STEither t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l") - ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r") + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k + + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is simply written there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".j.tag") (CELit (s++".j.tag"))) - <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0")) - stmts1 stmts2)) - mempty)) - add (STMaybe t1) d s = do + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit 
(s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".tag") (CELit "1")) <> stmts1) - mempty)) - add (STArr n t1) d s = do + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. 
fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty + shsizename <- genName' "acshsz" + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) ivar <- genName' "i" - ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]") - ((), stmtsDecr) <- scope $ incrementVarAlways "accumarr" Decrement (STArr n (CHAD.d2 t1)) (s++".j") - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))) - -- TODO: emit check here for the source being either equal in shape to the destination - <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) - stmts1) - <> stmtsDecr))) - mempty - add (STScal sty) d s = case sty of - STI32 -> return () - STI64 -> return () - STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STBool -> return () - add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - dest <- accumRef t prj (nameacc++".buf->ac") nameidx - add (acPrjTy prj t) dest nameval + accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> + add (acPrjTy prj t) dest nameval emit $ SVerbatim $ "// compile EAccum end" + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval + return $ CEStruct (repSTy STNil) [] EError _ t s -> do @@ -1111,9 +1193,9 @@ compile' env = \case name <- emitStruct t return $ CEStruct name [] - EZero{} -> error "Compile: monoid operations should have been eliminated" - EPlus{} -> error "Compile: monoid operations should have been eliminated" - EOneHot{} -> 
error "Compile: monoid operations should have been eliminated" + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EIdx1{} -> error "Compile: not implemented: EIdx1" @@ -1144,6 +1226,7 @@ data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array | ATNoop -- ^ don't do anything here | ATProj String ArrayTree -- ^ descend one field deeper | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 | ATBoth ArrayTree ArrayTree -- ^ do both these paths smartATProj :: String -> ArrayTree -> ArrayTree @@ -1154,6 +1237,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree smartATCondTag ATNoop ATNoop = ATNoop smartATCondTag t t' = ATCondTag t t' +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree smartATBoth ATNoop t = t smartATBoth t ATNoop = t @@ -1169,6 +1256,9 @@ makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTre makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1204,6 +1294,15 @@ incrementVar' marker inc path (ATCondTag t1 t2) = do ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 emit $ SIf (CEBinop (CELit (path ++ 
".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 toLinearIdx :: SNat n -> String -> String -> CExpr @@ -1257,10 +1356,12 @@ compileShapeQuery (SS n) var = -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize SZ _ = CELit "1" -compileArrShapeSize n var = - foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") - | i <- [0 .. fromSNat n - 1]] +compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1347,8 +1448,7 @@ compileExtremum nameBase opName operator env e = do -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. 
fromSNat n - 1]] + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -1375,47 +1475,47 @@ compileExtremum nameBase opName operator env e = do -- | If this returns Nothing, there was nothing to copy because making a simple -- value copy in C already makes it suitable to write to. -copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) copyForWriting topty var = case topty of - STNil -> return Nothing + SMTNil -> return Nothing - STPair a b -> do + SMTPair a b -> do e1 <- copyForWriting a (var ++ ".a") e2 <- copyForWriting b (var ++ ".b") case (e1, e2) of (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct (repSTy topty) + _ -> return $ Just $ CEStruct toptyname [("a", fromMaybe (CELit (var++".a")) e1) ,("b", fromMaybe (CELit (var++".b")) e2)] - STEither a b -> do + SMTLEither a b -> do (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") case (e1, e2) of (Nothing, Nothing) -> return Nothing _ -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) (stmts2 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) return (Just (CELit name)) - STMaybe t -> do + SMTMaybe t -> do (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") case e1 of Nothing -> return Nothing Just e1' -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - 
(pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) return (Just (CELit name)) -- If there are no nested arrays, we know that a refcount of 1 means that the @@ -1423,10 +1523,10 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - STArr n t | not (hasArrays t) -> do + SMTArr n t | not (hasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" - emit $ SVarDeclUninit (repSTy (STArr n t)) name + emit $ SVarDeclUninit toptyname name when debugShapes $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" @@ -1438,11 +1538,11 @@ copyForWriting topty var = case topty of emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") @@ -1450,26 +1550,26 @@ copyForWriting topty var = case topty of printCExpr 0 databytes ");"]) return (Just (CELit name)) - STArr n t -> do + SMTArr n t -> do shszname <- genName' "shsz" emit $ SVarDecl True (repSTy 
tIx) shszname (compileArrShapeSize n var) let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes name <- genName - emit $ SVarDecl False (repSTy (STArr n t)) name - (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) ivar <- genName' "i" @@ -1484,9 +1584,10 @@ copyForWriting topty var = case topty of return (Just (CELit name)) - STScal _ -> return Nothing + SMTScal _ -> return Nothing - STAccum _ -> error "Compile: Nested accumulators not supported" + where + toptyname = repSTy (fromSMTy topty) zeroRefcountCheck :: STy t -> String -> String -> CompM () zeroRefcountCheck toptyp opname topvar = @@ -1521,6 +1622,14 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) combine :: (Monoid a, Monoid b, Monad 
m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do |