aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-28 11:56:40 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-28 11:56:40 +0100
commit955af83f664639701fdbee54718186e07b31d42f (patch)
tree30353d77c69b1dfdaf43797942dbf6e412a49450 /src
parent765b80616583322226284266605ab3a916da01db (diff)
Better fold D{1,2} primitives
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs30
-rw-r--r--src/AST/Count.hs90
-rw-r--r--src/AST/Pretty.hs22
-rw-r--r--src/AST/SplitLets.hs16
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/Analysis/Identity.hs27
-rw-r--r--src/CHAD.hs35
-rw-r--r--src/Interpreter.hs22
-rw-r--r--src/Simplify.hs4
9 files changed, 139 insertions, 109 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 2d4fd91..f7b63cf 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -70,17 +70,19 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1)
+ -- MapAccum-like (is it real mapaccum? If so, rename)
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
+ -> Expr x (t1 : t1 : env) (TPair t1 b)
+ -> Expr x env t1
+ -> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
- (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores)
- -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only?
+ (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
+ -- Reverse derivative of Efold1Inner.
EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative
- -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0
- -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true
- -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env (TArr (S n) t1) -- primal input array
- -> Expr x env (ZeroInfo t2)
- -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1
+ -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env t2 -- zero
+ -> Expr x (t2 : t2 : env) t2 -- plus
+ -> Expr x env (TArr (S n) b) -- extra data passed to function
-> Expr x env (TArr n t2) -- incoming cotangent
-> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
@@ -232,8 +234,8 @@ typeOf = \case
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape))
- EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2))
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
+ EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -282,7 +284,7 @@ extOf = \case
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
EFold1InnerD1 x _ _ _ _ -> x
- EFold1InnerD2 x _ _ _ _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -330,7 +332,7 @@ travExt f = \case
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
+ EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
@@ -391,7 +393,7 @@ subst' f w = \case
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e)
+ EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index d5afb5e..229661f 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -598,53 +598,65 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
- EFold1InnerD1{} ->
- -- TODO: currently somewhat pessimistic on usage here
+ EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
- SsPair' _ (SsArr' (SsPair' _ sTape)) ->
- go topexpr sTape (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull sTape ))) s)
- -- the primed patsyns also consume SsFull, so if the above doesn't match,
- -- something along the way is SsNone
- _ ->
- go topexpr SsNone (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull SsNone))) s)
- where
- go :: Expr x env (TPair (TArr n t1) (TArr (S n) (TPair t1 tape)))
- -> Substruc tape tape'
- -> (forall env'. Ex env' (TPair (TArr n t1) (TArr (S n) (TPair t1 tape'))) -> Ex env' t')
- -> r
- go (EFold1InnerD1 _ cm a b c) sTape project =
- occCountX (SsPair SsFull sTape) a $ \env1_2' mka ->
- withSome (scaleMany (Some env1_2')) $ \env1_2 ->
- occEnvPop' env1_2 $ \env1_1 _ ->
- occEnvPop' env1_1 $ \env1 _ ->
- occCountX SsFull b $ \env2 mkb ->
- occCountX SsFull c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ -- If nothing is necessary, we can execute a fold and then proceed to ignore it
+ SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
+ -- If we don't need the stores, still a fold suffices
+ SsPair' sP SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
+ -- If for whatever reason the additional stores themselves are
+ -- unnecessary but the shape of the array is, then oblige
+ SsPair' sP (SsArr' SsNone) ->
+ let STArr sn _ = typeOf e3
+ foldex =
+ elet (mapExt (\_ -> ext) e3) $
+ EPair ext
+ (EShape ext (evar IZ))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (mapExt (\_ -> ext) (weakenExpr WSink e2))
+ (evar IZ))
+ in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
+ k env1 $ \env' ->
+ eunPair (mkfoldex env') $ \_ eshape earr ->
+ EPair ext earr (EBuild ext sn eshape (ENil ext))
+ -- If at least some of the additional stores are required, we need to keep this a mapAccum
+ SsPair' _ (SsArr' sB) ->
+ -- TODO: propagate usage of primals
+ occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' _ ->
+ occCountX SsFull e2 $ \env2 mkb ->
+ occCountX SsFull e3 $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
- -- projectSmallerSubstruc SsFull s $
- project $
+ projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
(mkb env') (mkc env')
- go _ _ _ = error "impossible"
- EFold1InnerD2 _ cm t2 ef ep ezi ebog ed ->
- -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage
- occCountX SsFull ef $ \env1_4' mkef ->
- withSome (scaleMany (Some env1_4')) $ \env1_4 ->
- occEnvPop' env1_4 $ \env1_3 _ ->
- occEnvPop' env1_3 $ \env1_2 _ ->
- occEnvPop' env1_2 $ \env1_1 _ ->
- occEnvPop' env1_1 $ \env1 sTape ->
- occCountX SsFull ep $ \env2 mkep ->
- occCountX SsFull ezi $ \env3 mkezi ->
- occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog ->
+ EFold1InnerD2 _ cm ef ez eplus ebog ed ->
+ -- TODO: propagate usage of duals
+ occCountX SsFull ef $ \env1_2' mkef ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' sB ->
+ occCountX SsFull ez $ \env2' mkez ->
+ occCountX SsFull eplus $ \env3_2' mkeplus ->
+ occEnvPop' env3_2' $ \env3_1' _ ->
+ occEnvPop' env3_1' $ \env3' _ ->
+ occCountX (SsArr sB) ebog $ \env4 mkebog ->
occCountX SsFull ed $ \env5 mked ->
- withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env ->
+ withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env ->
k env $ \env' ->
projectSmallerSubstruc SsFull s $
- EFold1InnerD2 ext cm t2
- (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull))
- (mkep env') (mkezi env') (mkebog env') (mked env')
+ EFold1InnerD2 ext cm
+ (mkef (OccPush (OccPush env' () sB) () SsFull))
+ (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkebog env') (mked env')
EConst _ t x ->
k OccEnd $ \_ ->
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index afa62c6..587328d 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -245,21 +245,23 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
- EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do
- let STArr _ (STPair t1 ttape) = typeOf ebog
- name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef
- name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef
- name3 <- genNameIfUsedIn t1 (IS IZ) ef
- name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef
- ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef
- ep' <- ppExpr' 11 val ep
- ezi' <- ppExpr' 11 val ezi
+ EFold1InnerD2 _ cm ef ez eplus ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ t2 = typeOf ez
+ namef1 <- genNameIfUsedIn tB (IS IZ) ef
+ namef2 <- genNameIfUsedIn t2 IZ ef
+ ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
+ ez' <- ppExpr' 11 val ez
+ namep1 <- genNameIfUsedIn t2 (IS IZ) eplus
+ namep2 <- genNameIfUsedIn t2 IZ eplus
+ eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus
ebog' <- ppExpr' 11 val ebog
ed' <- ppExpr' 11 val ed
let opname = "fold1iD2" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr)
- [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed']
+ [ppLam [ppString namef1, ppString namef2] ef', ez'
+ ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed']
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index f75e795..6034084 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -38,10 +38,10 @@ splitLets' = \sub -> \case
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
- EFold1InnerD2 x cm t2 a b c d e ->
- let STArr _ t1 = typeOf b
- STArr _ (STPair _ ttape) = typeOf d
- in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e)
+ EFold1InnerD2 x cm a b c d e ->
+ let t2 = typeOf b
+ STArr _ tB = typeOf d
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
@@ -106,10 +106,10 @@ splitLets' = \sub -> \case
body
-- TODO: abstract this to splitN lol wtf
- split4 :: forall bind1 bind2 bind3 bind4 env' env t.
- (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
- -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
- split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ _split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ _split4 sub tbind1 tbind2 tbind3 tbind4 body =
let (ptrs1, bs1') = split @env' tbind1
(ptrs2, bs2') = split @(bind1 : env') tbind2
(ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 555f0ec..6904715 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -45,7 +45,7 @@ unMonoid = \case
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
+ EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 2fa156d..9dc8811 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -255,20 +255,21 @@ idana env expr = case expr of
res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
pure (res, EFold1InnerD1 res cm e1' e2' e3')
- EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do
- let STArr _ (STPair t1 ttape) = typeOf ebog
- x1 <- genIds (fromSMTy t2)
- x2 <- genIds t1
- x3 <- genIds t1
- x4 <- genIds ttape
- (_, e1') <- idana (x1 `SCons` x2 `SCons` x3 `SCons` x4 `SCons` env) ef
- (v2, e2') <- idana env ep
- (_, e3') <- idana env ezi
- (_, e4') <- idana env ebog
+ EFold1InnerD2 _ cm ef ez eplus ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ t2 = typeOf ez
+ xf1 <- genIds t2
+ xf2 <- genIds tB
+ (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef
+ (_, e2') <- idana env ez
+ xp1 <- genIds t2
+ xp2 <- genIds t2
+ (_, e3') <- idana (xp1 `SCons` xp2 `SCons` env) eplus
+ (v4, e4') <- idana env ebog
(_, e5') <- idana env ed
- let VIArr _ sh = v2
- res <- VIPair <$> genIds (fromSMTy t2) <*> (VIArr <$> genId <*> pure sh)
- pure (res, EFold1InnerD2 res cm t2 e1' e2' e3' e4' e5')
+ let VIArr _ sh = v4
+ res <- VIPair <$> genIds t2 <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm e1' e2' e3' e4' e5')
EConst _ t val -> do
res <- VIScal <$> genId
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 25d26a6..93fabf9 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1133,9 +1133,11 @@ drev des accumMap sd = \case
let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
+ &. #px (auto1 @(D1 elt))
&. #pzi (auto1 @(ZeroInfo (D2 elt)))
&. #primal (primalTy `SCons` SNil)
&. #darr (auto1 @(TArr n sdElt))
@@ -1145,6 +1147,7 @@ drev des accumMap sd = \case
&. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
&. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
&. #ftape (auto1 @(Tape e_tape))
+ &. #primalzip (zipPrimalTy `SCons` SNil)
&. #efPrerebinds efPrerebinds
&. #propr (d1e envPro)
&. #d1env (desD1E des)
@@ -1166,11 +1169,14 @@ drev des accumMap sd = \case
(#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
layout)
ef0)) $
- EPair ext
- (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: layout))
- ef1)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout))))
+ elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: layout))
+ ef1) $
+ EPair ext
+ (evar IZ)
+ (EPair ext
+ (evar IZ)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
(SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
@@ -1181,19 +1187,24 @@ drev des accumMap sd = \case
(uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
let layout2 = #d2acPro :++: layout1 in
- EFold1InnerD2 ext commut (d2M eltty)
- (letBinds (efRebinds (IS (IS (IS IZ)))) $
- let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
+ elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
+ elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
+ letBinds (efRebinds (IS (IS IZ))) $
+ let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
.> wPro (subList (bindingsBinds ef0) subtapeEf))
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
- (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))
- (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)))
+ (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (ezip
+ (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
(ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
- (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
(EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
plus_x₀a_f
(weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index db66540..db7033d 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -148,7 +148,7 @@ interpret'Rec env = \case
arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
EFold1InnerD1 _ _ a b c -> do
let t = typeOf b
- let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
@@ -160,23 +160,25 @@ interpret'Rec env = \case
return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
- EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do
- let STArr _ (STPair t1 ttape) = typeOf ebog
- let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef
- parr <- interpret' env ep
- zi <- interpret' env ezi
+ EFold1InnerD2 _ _ ef ez eplus ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ t2 = typeOf ez
+ let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
+ zeroval <- interpret' env ez
+ let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus
bog <- interpret' env ebog
arrctg <- interpret' env ed
- let sh `ShCons` n = arrayShape parr
+ let sh `ShCons` n = arrayShape bog
res <- arrayGenerateM sh $ \idx -> do
let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
loop i !ctg !inpctgs = do
- let (prefix, tape) = arrayIndex bog (idx `IxCons` i)
- (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg
+ let b = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f b ctg
loop (i - 1) ctg1 (ctg2 : inpctgs)
(x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
- return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res)
+ x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res)
+ return (x0ctg
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v
diff --git a/src/Simplify.hs b/src/Simplify.hs
index c1f92f1..aac9963 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -315,7 +315,7 @@ simplify'Rec = \case
EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
- EFold1InnerD2 _ cm t2 a b c d e -> [simprec| EFold1InnerD2 ext cm t2 *a *b *c *d *e |]
+ EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |]
EConst _ t v -> pure $ EConst ext t v
EIdx0 _ e -> [simprec| EIdx0 ext *e |]
EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
@@ -370,7 +370,7 @@ hasAdds = \case
EMaximum1Inner _ e -> hasAdds e
EMinimum1Inner _ e -> hasAdds e
EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- EFold1InnerD2 _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
+ EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e