aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
commit42176d4a8a0fe7954f17da5c0506721695aa477f (patch)
tree8a29e847faa613e9becf1bccdcaad010187e639b
parent7729c45a325fe653421d654ed4c28b040585fce9 (diff)
WIP fold: everything but Compile (slow, but should be sound)
-rw-r--r--src/AST.hs16
-rw-r--r--src/AST/Count.hs68
-rw-r--r--src/AST/Pretty.hs34
-rw-r--r--src/AST/SplitLets.hs36
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/Analysis/Identity.hs26
-rw-r--r--src/CHAD.hs20
-rw-r--r--src/ForwardAD/DualNumbers.hs4
-rw-r--r--src/Interpreter.hs44
-rw-r--r--src/Simplify.hs4
10 files changed, 222 insertions, 32 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 275abcd..2d4fd91 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -70,15 +70,16 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1)
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
(TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores)
- -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage.
- EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative
+ -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only?
+ EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative
-> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0
-- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true
-> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
-> Expr x env (TArr (S n) t1) -- primal input array
+ -> Expr x env (ZeroInfo t2)
-> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1
-> Expr x env (TArr n t2) -- incoming cotangent
-> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
@@ -231,6 +232,9 @@ typeOf = \case
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape))
+ EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2))
+
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
@@ -277,6 +281,8 @@ extOf = \case
EReplicate1Inner x _ _ -> x
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
+ EFold1InnerD1 x _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -323,6 +329,8 @@ travExt f = \case
EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
@@ -382,6 +390,8 @@ subst' f w = \case
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index bec5a9d..cb363a3 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -356,7 +356,14 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
fullOccEnv SNil = OccEnd
fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
--- * s: how much of the result is required
+-- In one traversal, count occurrences of variables and determine what parts of
+-- expressions are actually used. These two results are computed independently:
+-- even if (almost) nothing of a particular term is actually used, variable
+-- references in that term still count as usual.
+--
+-- In @occCountX s t k@:
+-- * s: how much of the result of this term is required
+-- * t: the term to analyse
-- * k: is passed the actual environment usage of this expression, including
-- occurrence counts. The callback reconstructs a new expression in an
-- updated "response" environment. The response must be at least as large as
@@ -434,8 +441,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env1' $ \env1 s1 ->
occEnvPop' env2' $ \env2 s2 ->
occCountX (SsEither s1 s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
k env $ \env' ->
ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
ENothing _ t ->
@@ -459,8 +465,7 @@ occCountX initialS topexpr k = case topexpr of
occCountX s b $ \env2' mkb ->
occEnvPop' env2' $ \env2 s2 ->
occCountX (SsMaybe s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
k env $ \env' ->
EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
ELNil _ t1 t2 ->
@@ -497,9 +502,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env2' $ \env2 s1 ->
occEnvPop' env3' $ \env3 s2 ->
occCountX (SsLEither s1 s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <||> Some env3) $ \env123 ->
- withSome (Some env123 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
k env $ \env' ->
ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
@@ -550,16 +553,11 @@ occCountX initialS topexpr k = case topexpr of
occCountX (SsArr sElt) c $ \env3 mkc ->
withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
- let expr = EFold1Inner ext commut
- (projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
- (mkb env') (mkc env') in
- case s of
- SsNone -> use expr $ ENil ext
- SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr
- SsFull -> case testEquality sElt SsFull of
- Just Refl -> expr
- Nothing -> error "unreachable"
+ projectSmallerSubstruc (SsArr sElt) s $
+ EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush (OccPush env' () sElt) () sElt))
+ (mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -594,6 +592,40 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+ EFold1InnerD1 _ cm a b c ->
+ -- TODO: currently maximally pessimistic on usage here; in particular,
+ -- usage tracking of the 'tape' stores can be improved
+ occCountX SsFull a $ \env1_2' mka ->
+ withSome (scaleMany (Some env1_2')) $ \env1_2 ->
+ occEnvPop' env1_2 $ \env1_1 _ ->
+ occEnvPop' env1_1 $ \env1 _ ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull c $ \env3 mkc ->
+ withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkb env') (mkc env')
+
+ EFold1InnerD2 _ cm t2 ef ep ezi ebog ed ->
+ -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage
+ occCountX SsFull ef $ \env1_4' mkef ->
+ withSome (scaleMany (Some env1_4')) $ \env1_4 ->
+ occEnvPop' env1_4 $ \env1_3 _ ->
+ occEnvPop' env1_3 $ \env1_2 _ ->
+ occEnvPop' env1_2 $ \env1_1 _ ->
+ occEnvPop' env1_1 $ \env1 sTape ->
+ occCountX SsFull ep $ \env2 mkep ->
+ occCountX SsFull ezi $ \env3 mkezi ->
+ occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog ->
+ occCountX SsFull ed $ \env5 mked ->
+ withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD2 ext cm t2
+ (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull))
+ (mkep env') (mkezi env') (mkebog env') (mked env')
+
EConst _ t x ->
k OccEnd $ \_ ->
case s of
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 9018602..afa62c6 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -72,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex
_ -> return "_"
| otherwise = genName' prefix
+-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
@@ -209,8 +210,7 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- let opname = case cm of Commut -> "fold1i(C)"
- Noncommut -> "fold1i"
+ let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
@@ -235,6 +235,32 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EFold1InnerD1 _ cm a b c -> do
+ name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf b) IZ a
+ a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1iD1" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+
+ EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do
+ let STArr _ (STPair t1 ttape) = typeOf ebog
+ name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef
+ name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef
+ name3 <- genNameIfUsedIn t1 (IS IZ) ef
+ name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef
+ ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef
+ ep' <- ppExpr' 11 val ep
+ ezi' <- ppExpr' 11 val ezi
+ ebog' <- ppExpr' 11 val ebog
+ ed' <- ppExpr' 11 val ed
+ let opname = "fold1iD2" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr)
+ [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed']
+
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -386,6 +412,10 @@ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
ppSparse (SMTScal _) SpScal = "."
+ppCommut :: Commutative -> String
+ppCommut Commut = "(C)"
+ppCommut Noncommut = ""
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 82ec1d6..f75e795 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' = \sub -> \case
EVar _ t i -> sub t i WId
- ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
ECase x e a b ->
let STEither t1 t2 = typeOf e
in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
@@ -35,6 +35,13 @@ splitLets' = \sub -> \case
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD1 x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD2 x cm t2 a b c d e ->
+ let STArr _ t1 = typeOf b
+ STArr _ (STPair _ ttape) = typeOf d
+ in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
@@ -98,6 +105,33 @@ splitLets' = \sub -> \case
t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
body
+ -- TODO: abstract this to splitN lol wtf
+ split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ let (ptrs1, bs1') = split @env' tbind1
+ (ptrs2, bs2') = split @(bind1 : env') tbind2
+ (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
+ (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
+ bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
+ bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
+ bs3 = fst (weakenBindingsE WSink bs3')
+ b1 = bindingsBinds bs1
+ b2 = bindingsBinds bs2
+ b3 = bindingsBinds bs3
+ b4 = bindingsBinds bs4
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
+ _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
+ _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
+ _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
+ t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
+ body
+
type family Split t where
Split (TPair a b) = SplitRec (TPair a b)
Split _ = '[]
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 48dd709..555f0ec 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -44,6 +44,8 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index b54946b..2fa156d 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -244,6 +244,32 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> pure sh
pure (res, EMinimum1Inner res e1')
+ EFold1InnerD1 _ cm e1 e2 e3 -> do
+ let t1 = typeOf e2
+ x1 <- genIds t1
+ x2 <- genIds t1
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ sh'@(_ :< sh) = v3
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
+ pure (res, EFold1InnerD1 res cm e1' e2' e3')
+
+ EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do
+ let STArr _ (STPair t1 ttape) = typeOf ebog
+ x1 <- genIds (fromSMTy t2)
+ x2 <- genIds t1
+ x3 <- genIds t1
+ x4 <- genIds ttape
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` x3 `SCons` x4 `SCons` env) ef
+ (v2, e2') <- idana env ep
+ (_, e3') <- idana env ezi
+ (_, e4') <- idana env ebog
+ (_, e5') <- idana env ed
+ let VIArr _ sh = v2
+ res <- VIPair <$> genIds (fromSMTy t2) <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm t2 e1' e2' e3' e4' e5')
+
EConst _ t val -> do
res <- VIScal <$> genId
pure (res, EConst res t val)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index ec719e8..25d26a6 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1136,6 +1136,7 @@ drev des accumMap sd = \case
library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
+ &. #pzi (auto1 @(ZeroInfo (D2 elt)))
&. #primal (primalTy `SCons` SNil)
&. #darr (auto1 @(TArr n sdElt))
&. #d (auto1 @(D2 elt))
@@ -1157,23 +1158,25 @@ drev des accumMap sd = \case
subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
Ret (bconcat bindsx₀a mergePrimalBindings'
`bpush` weakenExpr wOverPrimalBindings ex₀1
- `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1
+ `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
- (letBinds (fst (weakenBindingsE (autoWeak library
+ (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
+ letBinds (fst (weakenBindingsE (autoWeak library
(#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ layout)
ef0)) $
EPair ext
(weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ (#fbinds :++: layout))
ef1)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))))
- (EVar ext (d1 eltty) (IS IZ))
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout))))
+ (EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
- (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
(EFst ext (EVar ext primalTy IZ))
subx₀af
- (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
(uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
@@ -1187,6 +1190,7 @@ drev des accumMap sd = \case
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))
(ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
(ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
(EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index c98b6c0..467b895 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -194,9 +194,13 @@ dfwdDN = \case
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
+
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
+ err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program"
deriv_extremum :: ScalIsNumeric t ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ffc2929..db66540 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,8 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict (runStateT, get, put)
import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
@@ -28,6 +30,7 @@ import Data.Functor.Identity
import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
+import Data.Tuple (swap)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
@@ -143,6 +146,39 @@ interpret'Rec env = \case
sh `ShCons` n = arrayShape arr
numericIsNum t $ return $
arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EFold1InnerD1 _ _ a b c -> do
+ let t = typeOf b
+ let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ -- TODO: this is very inefficient, even for an interpreter; with mutable
+ -- arrays this can be a lot better with no lists
+ res <- arrayGenerateM sh $ \idx -> do
+ (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ return (y, arrayFromList (ShNil `ShCons` n) stores)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do
+ let STArr _ (STPair t1 ttape) = typeOf ebog
+ let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef
+ parr <- interpret' env ep
+ zi <- interpret' env ezi
+ bog <- interpret' env ebog
+ arrctg <- interpret' env ed
+ let sh `ShCons` n = arrayShape parr
+ res <- arrayGenerateM sh $ \idx -> do
+ let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
+ loop i !ctg !inpctgs = do
+ let (prefix, tape) = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg
+ loop (i - 1) ctg1 (ctg2 : inpctgs)
+ (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
+ return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
+ return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res)
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
@@ -411,3 +447,11 @@ ixUncons (IxCons idx i) = (idx, i)
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)
+
+mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b)
+mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f'
+ where f' x = do
+ s <- get
+ (s', y) <- lift $ f s x
+ put s'
+ return y
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 51870d4..c1f92f1 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -314,6 +314,8 @@ simplify'Rec = \case
EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
+ EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
+ EFold1InnerD2 _ cm t2 a b c d e -> [simprec| EFold1InnerD2 ext cm t2 *a *b *c *d *e |]
EConst _ t v -> pure $ EConst ext t v
EIdx0 _ e -> [simprec| EIdx0 ext *e |]
EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
@@ -367,6 +369,8 @@ hasAdds = \case
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
EMaximum1Inner _ e -> hasAdds e
EMinimum1Inner _ e -> hasAdds e
+ EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ EFold1InnerD2 _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e