From 955af83f664639701fdbee54718186e07b31d42f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 Oct 2025 11:56:40 +0100 Subject: Better fold D{1,2} primitives --- src/AST.hs | 30 ++++++++-------- src/AST/Count.hs | 92 +++++++++++++++++++++++++++--------------------- src/AST/Pretty.hs | 22 ++++++------ src/AST/SplitLets.hs | 16 ++++----- src/AST/UnMonoid.hs | 2 +- src/Analysis/Identity.hs | 27 +++++++------- src/CHAD.hs | 35 +++++++++++------- src/Interpreter.hs | 22 ++++++------ src/Simplify.hs | 4 +-- 9 files changed, 140 insertions(+), 110 deletions(-) diff --git a/src/AST.hs b/src/AST.hs index 2d4fd91..f7b63cf 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -70,17 +70,19 @@ data Expr x env t where EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) + -- MapAccum-like (is it real mapaccum? If so, rename) + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (t1 : t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output - (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores) - -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only? + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of Efold1Inner. EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative - -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0 - -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true - -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env (TArr (S n) t1) -- primal input array - -> Expr x env (ZeroInfo t2) - -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1 + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env t2 -- zero + -> Expr x (t2 : t2 : env) t2 -- plus + -> Expr x env (TArr (S n) b) -- extra data passed to function -> Expr x env (TArr n t2) -- incoming cotangent -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array @@ -232,8 +234,8 @@ typeOf = \case EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape)) - EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2)) + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -282,7 +284,7 @@ extOf = \case EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x EFold1InnerD1 x _ _ _ _ -> x - EFold1InnerD2 x _ _ _ _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -330,7 +332,7 @@ travExt f = \case EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e + EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -391,7 +393,7 @@ subst' f w = \case EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e) + EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d5afb5e..229661f 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -598,53 +598,65 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - EFold1InnerD1{} -> - -- TODO: currently somewhat pessimistic on usage here + EFold1InnerD1 _ cm e1 e2 e3 -> case s of - SsPair' _ (SsArr' (SsPair' _ sTape)) -> - go topexpr sTape (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull sTape ))) s) - -- the primed patsyns also consume SsFull, so if the above doesn't match, - -- something along the way is SsNone - _ -> - go topexpr SsNone (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull SsNone))) s) - where - go :: Expr x env (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) - -> Substruc tape tape' - -> (forall env'. Ex env' (TPair (TArr n t1) (TArr (S n) (TPair t1 tape'))) -> Ex env' t') - -> r - go (EFold1InnerD1 _ cm a b c) sTape project = - occCountX (SsPair SsFull sTape) a $ \env1_2' mka -> - withSome (scaleMany (Some env1_2')) $ \env1_2 -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 _ -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> - -- projectSmallerSubstruc SsFull s $ - project $ + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - go _ _ _ = error "impossible" - - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> - -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage - occCountX SsFull ef $ \env1_4' mkef -> - withSome (scaleMany (Some env1_4')) $ \env1_4 -> - occEnvPop' env1_4 $ \env1_3 _ -> - occEnvPop' env1_3 $ \env1_2 _ -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 sTape -> - occCountX SsFull ep $ \env2 mkep -> - occCountX SsFull ezi $ \env3 mkezi -> - occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog -> + + EFold1InnerD2 _ cm ef ez eplus ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX SsFull ez $ \env2' mkez -> + occCountX SsFull eplus $ \env3_2' mkeplus -> + occEnvPop' env3_2' $ \env3_1' _ -> + occEnvPop' env3_1' $ \env3' _ -> + occCountX (SsArr sB) ebog $ \env4 mkebog -> occCountX SsFull ed $ \env5 mked -> - withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env -> + withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm t2 - (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull)) - (mkep env') (mkezi env') (mkebog env') (mked env') + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) + (mkebog env') (mked env') EConst _ t x -> k OccEnd $ \_ -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index afa62c6..587328d 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -245,21 +245,23 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do - let STArr _ (STPair t1 ttape) = typeOf ebog - name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef - name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef - name3 <- genNameIfUsedIn t1 (IS IZ) ef - name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef - ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef - ep' <- ppExpr' 11 val ep - ezi' <- ppExpr' 11 val ezi + EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + let STArr _ tB = typeOf ebog + t2 = typeOf ez + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ez' <- ppExpr' 11 val ez + namep1 <- genNameIfUsedIn t2 (IS IZ) eplus + namep2 <- genNameIfUsedIn t2 IZ eplus + eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus ebog' <- ppExpr' 11 val ebog ed' <- ppExpr' 11 val ed let opname = "fold1iD2" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed'] + [ppLam [ppString namef1, ppString namef2] ef', ez' + ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed'] EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index f75e795..6034084 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -38,10 +38,10 @@ splitLets' = \sub -> \case EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm t2 a b c d e -> - let STArr _ t1 = typeOf b - STArr _ (STPair _ ttape) = typeOf d - in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e) + EFold1InnerD2 x cm a b c d e -> + let t2 = typeOf b + STArr _ tB = typeOf d + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -106,10 +106,10 @@ splitLets' = \sub -> \case body -- TODO: abstract this to splitN lol wtf - split4 :: forall bind1 bind2 bind3 bind4 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t - split4 sub tbind1 tbind2 tbind3 tbind4 body = + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = let (ptrs1, bs1') = split @env' tbind1 (ptrs2, bs2') = split @(bind1 : env') tbind2 (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 555f0ec..6904715 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -45,7 +45,7 @@ unMonoid = \case EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) + EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 2fa156d..9dc8811 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -255,20 +255,21 @@ idana env expr = case expr of res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') pure (res, EFold1InnerD1 res cm e1' e2' e3') - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do - let STArr _ (STPair t1 ttape) = typeOf ebog - x1 <- genIds (fromSMTy t2) - x2 <- genIds t1 - x3 <- genIds t1 - x4 <- genIds ttape - (_, e1') <- idana (x1 `SCons` x2 `SCons` x3 `SCons` x4 `SCons` env) ef - (v2, e2') <- idana env ep - (_, e3') <- idana env ezi - (_, e4') <- idana env ebog + EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + let STArr _ tB = typeOf ebog + t2 = typeOf ez + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (_, e2') <- idana env ez + xp1 <- genIds t2 + xp2 <- genIds t2 + (_, e3') <- idana (xp1 `SCons` xp2 `SCons` env) eplus + (v4, e4') <- idana env ebog (_, e5') <- idana env ed - let VIArr _ sh = v2 - res <- VIPair <$> genIds (fromSMTy t2) <*> (VIArr <$> genId <*> pure sh) - pure (res, EFold1InnerD2 res cm t2 e1' e2' e3' e4' e5') + let VIArr _ sh = v4 + res <- VIPair <$> genIds t2 <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3' e4' e5') EConst _ t val -> do res <- VIScal <$> genId diff --git a/src/CHAD.hs b/src/CHAD.hs index 25d26a6..93fabf9 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1133,9 +1133,11 @@ drev des accumMap sd = \case let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) primalTy = STPair (STArr ndim (d1 eltty)) bogTy + zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) &. #pzi (auto1 @(ZeroInfo (D2 elt))) &. #primal (primalTy `SCons` SNil) &. #darr (auto1 @(TArr n sdElt)) @@ -1145,6 +1147,7 @@ drev des accumMap sd = \case &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) &. #ftape (auto1 @(Tape e_tape)) + &. #primalzip (zipPrimalTy `SCons` SNil) &. #efPrerebinds efPrerebinds &. #propr (d1e envPro) &. #d1env (desD1E des) @@ -1166,11 +1169,14 @@ drev des accumMap sd = \case (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) layout) ef0)) $ - EPair ext - (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout)))) + elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: layout)) + ef1) $ + EPair ext + (evar IZ) + (EPair ext + (evar IZ) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) @@ -1181,19 +1187,24 @@ drev des accumMap sd = \case (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ makeAccumulators (autoWeak library #propr layout1) envPro $ let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut (d2M eltty) - (letBinds (efRebinds (IS (IS (IS IZ)))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in + EFold1InnerD2 ext commut + (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ + elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ + elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ + letBinds (efRebinds (IS (IS IZ))) $ + let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 .> wPro (subList (bindingsBinds ef0) subtapeEf)) ef2) $ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))) + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (ezip + (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ plus_x₀a_f (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ diff --git a/src/Interpreter.hs b/src/Interpreter.hs index db66540..db7033d 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -148,7 +148,7 @@ interpret'Rec env = \case arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) EFold1InnerD1 _ _ a b c -> do let t = typeOf b - let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr @@ -160,23 +160,25 @@ interpret'Rec env = \case return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) - EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do - let STArr _ (STPair t1 ttape) = typeOf ebog - let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef - parr <- interpret' env ep - zi <- interpret' env ezi + EFold1InnerD2 _ _ ef ez eplus ebog ed -> do + let STArr _ tB = typeOf ebog + t2 = typeOf ez + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + zeroval <- interpret' env ez + let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus bog <- interpret' env ebog arrctg <- interpret' env ed - let sh `ShCons` n = arrayShape parr + let sh `ShCons` n = arrayShape bog res <- arrayGenerateM sh $ \idx -> do let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) loop i !ctg !inpctgs = do - let (prefix, tape) = arrayIndex bog (idx `IxCons` i) - (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg loop (i - 1) ctg1 (ctg2 : inpctgs) (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) - return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res) + x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res) + return (x0ctg ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v diff --git a/src/Simplify.hs b/src/Simplify.hs index c1f92f1..aac9963 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -315,7 +315,7 @@ simplify'Rec = \case EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] - EFold1InnerD2 _ cm t2 a b c d e -> [simprec| EFold1InnerD2 ext cm t2 *a *b *c *d *e |] + EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -370,7 +370,7 @@ hasAdds = \case EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - EFold1InnerD2 _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e + EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e -- cgit v1.2.3-70-g09d2