aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 20:09:10 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 20:09:10 +0100
commit654b13d0de961788ed600e8eeb6c9fbbd736439e (patch)
tree4821c1082aa2fcd04ea6ecb0e07f8d3bdbe643ed
parent4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (diff)
Compile foldd1 and foldd2HEADmaster
-rw-r--r--src/Compile.hs114
-rw-r--r--src/Language.hs8
-rw-r--r--src/Language/AST.hs12
-rw-r--r--test/Main.hs9
4 files changed, 141 insertions, 2 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 4e81c6a..f2063ee 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -34,6 +34,7 @@ import Foreign
import GHC.Exts (int2Word#, addr2Int#)
import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (mkIOError, userErrorType)
@@ -939,6 +940,117 @@ compile' env = \case
[("buf", CEProj (CELit arrname) "buf")
,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -1311,7 +1423,7 @@ data AllocMethod = Malloc | Calloc
deriving (Show)
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
allocArray marker method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
diff --git a/src/Language.hs b/src/Language.hs
index d3c38d6..31b4b87 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -133,6 +133,14 @@ minimum1i e = NEMinimum1Inner e
reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
reshape = NEReshape
+fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3
+
+fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
+ -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 325817d..c9d05c9 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -58,6 +58,15 @@ data NExpr env t where
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+ NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
@@ -208,6 +217,9 @@ fromNamedExpr val = \case
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEReshape n a b -> EReshape ext n (go a) (go b)
+ NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
NEIdx1 a b -> EIdx1 ext (go a) (go b)
diff --git a/test/Main.hs b/test/Main.hs
index 2acc9f8..4bc9082 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -396,7 +396,7 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
+ annotate (ppExpr env dtermSChadSUSP)
diff outChad0 closeIsh outPrimal
diff outChadS closeIsh outPrimal
diff outChadSUS closeIsh outPrimal
@@ -562,6 +562,13 @@ tests_Compile = testGroup "Compile"
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
+
+ ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $
+ fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a
+
+ ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $
+ let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $
+ fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d
]
tests_AD :: TestTree