summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 23:47:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 23:47:27 +0100
commit16a836d078caefc3526031c084e2527cba0da3a8 (patch)
tree7ca6a28063e259823750ee37bd4bebae27c5408d /src/Compile.hs
parentf1e867838db63da71fea660740c23ab276a43a6c (diff)
Fix various issues in Compile (still broken)
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs196
1 files changed, 148 insertions, 48 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 5501746..00b90e3 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -240,7 +240,8 @@ genStruct name topty = case topty of
STScal _ ->
[]
STAccum t ->
- [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com]
+ [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
where
com = ppSTy 0 topty
@@ -637,10 +638,8 @@ compile' env = \case
EBuild _ n esh efun -> do
shname <- compileAssign "sh" env esh
- shsizename <- genName' "shsz"
- emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
- arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (indexTupleComponents n shname)
+ arrname <- allocArray Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
idxargname <- genName' "ix"
(funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
@@ -671,7 +670,7 @@ compile' env = \case
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
- resname <- allocArray "foldres" n t (CELit shszname)
+ resname <- allocArray Malloc "foldres" n t (Just (CELit shszname))
[CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
lenname <- genName' "n"
@@ -715,7 +714,7 @@ compile' env = \case
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "sumres" n t (CELit shszname)
+ resname <- allocArray Malloc "sumres" n t (Just (CELit shszname))
[CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
lenname <- genName' "n"
@@ -757,8 +756,8 @@ compile' env = \case
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "rep" (SS n) t
- (CEBinop (CELit shszname) "*" (CELit lenname))
+ resname <- allocArray Malloc "rep" (SS n) t
+ (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
++ [CELit lenname])
@@ -850,28 +849,98 @@ compile' env = \case
zeroRefcountCheck (typeOf e1) "with" name1
+ emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
mcopy <- copyForWriting (CHAD.d2 t) name1
accname <- genName' "accum"
- emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])
+ emit $ SVarDecl False actyname accname
+ (CEStruct actyname [("buf", CECall "malloc" [CELit (show (sizeofSTy (CHAD.d2 t)))])])
+ emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
+ emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
e2' <- compile' (Const accname `SCons` env) e2
+ resname <- genName' "acret"
+ emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVerbatim $ "free(" ++ accname ++ ".buf);"
+
rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
- return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]
+ return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
EAccum _ t prj eidx eval eacc -> do
nameidx <- compileAssign "acidx" env eidx
nameval <- compileAssign "acval" env eval
- nameacc <- compileAssign "acac" env eacc
+ -- Generate the variable manually because this one has to be non-const.
+ eacc' <- compile' env eacc
+ nameacc <- genName' "acac"
+ emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
+
+ let -- Expects a variable reference to a value of type @D2 a@.
+ setZero :: STy a -> String -> CompM ()
+ setZero STNil _ = return ()
+ setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b))
+ setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b))
+ setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a)
+ setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a))
+ setZero (STScal sty) v = case sty of
+ STI32 -> return () -- Nil
+ STI64 -> return () -- Nil
+ STF32 -> emit $ SAsg v (CELit "0.0f")
+ STF64 -> emit $ SAsg v (CELit "0.0")
+ STBool -> return () -- Nil
+ setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported"
+
+ initD2Pair :: STy a -> STy b -> String -> CompM ()
+ initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b))
+ ((), stmts1) <- scope $ setZero a (v++".j.a")
+ ((), stmts2) <- scope $ setZero b (v++".j.b")
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2)
+ mempty
+
+ initD2Either :: STy a -> STy b -> String -> Either () () -> CompM ()
+ initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b))
+ ((), stmts) <- case side of
+ Left () -> scope $ setZero a (v++".j.l")
+ Right () -> scope $ setZero b (v++".j.r")
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
+ mempty
+
+ initD2Maybe :: STy a -> String -> CompM ()
+ initD2Maybe a v = do -- Maybe (D2 a)
+ ((), stmts) <- scope $ setZero a (v++".j")
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
+ mempty
+
+ -- mind: this has to traverse the D2 of these things, and it also has to
+ -- initialise data structures that are still sparse in the accumulator.
let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String
accumRef _ SAPHere v _ = pure v
- accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i
- accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i
- accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i
- accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i
- accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i
+ accumRef (STPair ta tb) (SAPFst prj') v i = do
+ initD2Pair ta tb v
+ accumRef ta prj' (v++".j.a") i
+ accumRef (STPair ta tb) (SAPSnd prj') v i = do
+ initD2Pair ta tb v
+ accumRef tb prj' (v++".j.b") i
+ accumRef (STEither ta tb) (SAPLeft prj') v i = do
+ initD2Either ta tb v (Left ())
+ accumRef ta prj' (v++".j.l") i
+ accumRef (STEither ta tb) (SAPRight prj') v i = do
+ initD2Either ta tb v (Right ())
+ accumRef tb prj' (v++".j.r") i
+ accumRef (STMaybe tj) (SAPJust prj') v i = do
+ initD2Maybe tj v
+ accumRef tj prj' (v++".j") i
accumRef (STArr n t') (SAPArrIdx prj' _) v i = do
+ (newarrName, newarrStmts) <- scope $ allocArray Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b"))
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1"))
+ <> newarrStmts
+ <> pure (SAsg (v++".j") (CELit newarrName)))
+ mempty
+
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
forM_ (zip3 [0::Int ..]
@@ -880,58 +949,77 @@ compile' env = \case
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))
.||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
+ CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++
"arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
+ v ++ ".j.buf" ++
+ concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
"); " ++
"abort();")
mempty
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b")
+ accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b")
+ -- mind: this has to add the D2 of these things, and it also has to
+ -- initialise data structures that are still sparse in the accumulator.
let add :: STy a -> String -> String -> CompM ()
add STNil _ _ = return ()
add (STPair t1 t2) d s = do
- add t1 (d++".a") (s++".a")
- add t2 (d++".b") (s++".b")
+ ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a")
+ ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s)))
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ (stmts1 <> stmts2)
+ mempty))
add (STEither t1 t2) d s = do
- ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
- ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
- emit $ SAsg (d++".tag") (CELit (s++".tag"))
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0"))
- stmts1 stmts2
+ ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l")
+ ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s)))
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ (pure (SAsg (d++".j.tag") (CELit (s++".j.tag")))
+ <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0"))
+ stmts1 stmts2))
+ mempty))
add (STMaybe t1) d s = do
((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
- emit $ SAsg (d++".tag") (CELit (s++".tag"))
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- stmts1 mempty
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s)))
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ (pure (SAsg (d++".tag") (CELit "1")) <> stmts1)
+ mempty))
add (STArr n t1) d s = do
shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileShapeSize n (s++".a.b"))
ivar <- genName' "i"
- -- TODO: emit check here for the source being either empty or equal in shape to the destination
- ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $
- stmts1
+ ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]")
+ emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s)))
+ (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")))
+ -- TODO: emit check here for the source being either equal in shape to the destination
+ <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1))))
+ mempty
add (STScal sty) d s = case sty of
- STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ STI32 -> return ()
+ STI64 -> return ()
STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STBool -> error "Compile: accumulator add on booleans"
+ STBool -> return ()
add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
- dest <- accumRef t prj (nameacc++".ac") nameidx
- add (typeOf eval) dest nameval
+ emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
+ dest <- accumRef t prj (nameacc++".buf->ac") nameidx
+ add (acPrjTy prj t) dest nameval
incrementVarAlways Decrement (typeOf eval) nameval
+ emit $ SVerbatim $ "// compile EAccum end"
return $ CEStruct (repSTy STNil) []
@@ -998,7 +1086,7 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
-makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t)
+makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
@@ -1054,18 +1142,26 @@ toLinearIdx (SS n) arrvar idxvar =
-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
-- _
+data AllocMethod = Malloc | Calloc
+ deriving (Show)
+
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String
-allocArray nameBase rank eltty shsz shape = do
+allocArray :: AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
let arrty = STArr rank eltty
strname <- emitStruct arrty
arrname <- genName' nameBase
+ shsz <- case mshsz of
+ Just e -> return e
+ Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
+ let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
+ "+"
+ (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
emit $ SVarDecl True strname arrname $ CEStruct strname
- [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8)))
- "+"
- (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])]
+ [("buf", case method of Malloc -> CECall "malloc" [nbytesExpr]
+ Calloc -> CECall "calloc" [nbytesExpr, CELit "1"])]
forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
@@ -1175,7 +1271,7 @@ compileExtremum nameBase opName operator env e = do
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray (nameBase ++ "res") n t (CELit shszname)
+ resname <- allocArray Malloc (nameBase ++ "res") n t (Just (CELit shszname))
[CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
lenname <- genName' "n"
@@ -1369,3 +1465,7 @@ showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "
-- | Type-restricted.
(^) :: Num a => a -> Int -> a
(^) = (Prelude.^)
+
+foldl0' :: (a -> a -> a) -> a -> [a] -> a
+foldl0' _ x [] = x
+foldl0' f _ l = foldl1' f l