diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
| commit | 42176d4a8a0fe7954f17da5c0506721695aa477f (patch) | |
| tree | 8a29e847faa613e9becf1bccdcaad010187e639b | |
| parent | 7729c45a325fe653421d654ed4c28b040585fce9 (diff) | |
WIP fold: everything but Compile (slow, but should be sound)
| -rw-r--r-- | src/AST.hs | 16 | ||||
| -rw-r--r-- | src/AST/Count.hs | 68 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 34 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 36 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 26 | ||||
| -rw-r--r-- | src/CHAD.hs | 20 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 4 | ||||
| -rw-r--r-- | src/Interpreter.hs | 44 | ||||
| -rw-r--r-- | src/Simplify.hs | 4 |
10 files changed, 222 insertions, 32 deletions
@@ -70,15 +70,16 @@ data Expr x env t where EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores) - -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. - EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative + -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only? + EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0 -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) -> Expr x env (TArr (S n) t1) -- primal input array + -> Expr x env (ZeroInfo t2) -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1 -> Expr x env (TArr n t2) -- incoming cotangent -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array @@ -231,6 +232,9 @@ typeOf = \case EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape)) + EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2)) + EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t @@ -277,6 +281,8 @@ extOf = \case EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -323,6 +329,8 @@ travExt f = \case EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -382,6 +390,8 @@ subst' f w = \case EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index bec5a9d..cb363a3 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -356,7 +356,14 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) fullOccEnv SNil = OccEnd fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull --- * s: how much of the result is required +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse -- * k: is passed the actual environment usage of this expression, including -- occurrence counts. The callback reconstructs a new expression in an -- updated "response" environment. The response must be at least as large as @@ -434,8 +441,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env1' $ \env1 s1 -> occEnvPop' env2' $ \env2 s2 -> occCountX (SsEither s1 s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> k env $ \env' -> ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) ENothing _ t -> @@ -459,8 +465,7 @@ occCountX initialS topexpr k = case topexpr of occCountX s b $ \env2' mkb -> occEnvPop' env2' $ \env2 s2 -> occCountX (SsMaybe s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> k env $ \env' -> EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') ELNil _ t1 t2 -> @@ -497,9 +502,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env2' $ \env2 s1 -> occEnvPop' env3' $ \env3 s2 -> occCountX (SsLEither s1 s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <||> Some env3) $ \env123 -> - withSome (Some env123 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> k env $ \env' -> ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) @@ -550,16 +553,11 @@ occCountX initialS topexpr k = case topexpr of occCountX (SsArr sElt) c $ \env3 mkc -> withSome (Some env1 <> Some env2 <> Some env3) $ \env -> k env $ \env' -> - let expr = EFold1Inner ext commut - (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) - (mkb env') (mkc env') in - case s of - SsNone -> use expr $ ENil ext - SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr - SsFull -> case testEquality sElt SsFull of - Just Refl -> expr - Nothing -> error "unreachable" + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush (OccPush env' () sElt) () sElt)) + (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -594,6 +592,40 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + EFold1InnerD1 _ cm a b c -> + -- TODO: currently maximally pessimistic on usage here; in particular, + -- usage tracking of the 'tape' stores can be improved + occCountX SsFull a $ \env1_2' mka -> + withSome (scaleMany (Some env1_2')) $ \env1_2 -> + occEnvPop' env1_2 $ \env1_1 _ -> + occEnvPop' env1_1 $ \env1 _ -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull c $ \env3 mkc -> + withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> + -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage + occCountX SsFull ef $ \env1_4' mkef -> + withSome (scaleMany (Some env1_4')) $ \env1_4 -> + occEnvPop' env1_4 $ \env1_3 _ -> + occEnvPop' env1_3 $ \env1_2 _ -> + occEnvPop' env1_2 $ \env1_1 _ -> + occEnvPop' env1_1 $ \env1 sTape -> + occCountX SsFull ep $ \env2 mkep -> + occCountX SsFull ezi $ \env3 mkezi -> + occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog -> + occCountX SsFull ed $ \env5 mked -> + withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm t2 + (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull)) + (mkep env') (mkezi env') (mkebog env') (mked env') + EConst _ t x -> k OccEnd $ \_ -> case s of diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 9018602..afa62c6 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -72,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -209,8 +210,7 @@ ppExpr' d val expr = case expr of a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] @@ -235,6 +235,32 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EFold1InnerD1 _ cm a b c -> do + name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a + name2 <- genNameIfUsedIn (typeOf b) IZ a + a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + + EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do + let STArr _ (STPair t1 ttape) = typeOf ebog + name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef + name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef + name3 <- genNameIfUsedIn t1 (IS IZ) ef + name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef + ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef + ep' <- ppExpr' 11 val ep + ezi' <- ppExpr' 11 val ezi + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -386,6 +412,10 @@ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s ppSparse (SMTScal _) SpScal = "." +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 82ec1d6..f75e795 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -35,6 +35,13 @@ splitLets' = \sub -> \case EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm t2 a b c d e -> + let STArr _ t1 = typeOf b + STArr _ (STPair _ ttape) = typeOf d + in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -98,6 +105,33 @@ splitLets' = \sub -> \case t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 48dd709..555f0ec 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -44,6 +44,8 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index b54946b..2fa156d 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -244,6 +244,32 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds t1 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do + let STArr _ (STPair t1 ttape) = typeOf ebog + x1 <- genIds (fromSMTy t2) + x2 <- genIds t1 + x3 <- genIds t1 + x4 <- genIds ttape + (_, e1') <- idana (x1 `SCons` x2 `SCons` x3 `SCons` x4 `SCons` env) ef + (v2, e2') <- idana env ep + (_, e3') <- idana env ezi + (_, e4') <- idana env ebog + (_, e5') <- idana env ed + let VIArr _ sh = v2 + res <- VIPair <$> genIds (fromSMTy t2) <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm t2 e1' e2' e3' e4' e5') + EConst _ t val -> do res <- VIScal <$> genId pure (res, EConst res t val) diff --git a/src/CHAD.hs b/src/CHAD.hs index ec719e8..25d26a6 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1136,6 +1136,7 @@ drev des accumMap sd = \case library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) &. #primal (primalTy `SCons` SNil) &. #darr (auto1 @(TArr n sdElt)) &. #d (auto1 @(D2 elt)) @@ -1157,23 +1158,25 @@ drev des accumMap sd = \case subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> Ret (bconcat bindsx₀a mergePrimalBindings' `bpush` weakenExpr wOverPrimalBindings ex₀1 - `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 `bpush` EFold1InnerD1 ext commut - (letBinds (fst (weakenBindingsE (autoWeak library + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + layout) ef0)) $ EPair ext (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + (#fbinds :++: layout)) ef1) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)))) - (EVar ext (d1 eltty) (IS IZ)) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout)))) + (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) (EFst ext (EVar ext primalTy IZ)) subx₀af - (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ makeAccumulators (autoWeak library #propr layout1) envPro $ @@ -1187,6 +1190,7 @@ drev des accumMap sd = \case ef2) $ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)) (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)) (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index c98b6c0..467b895 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -194,9 +194,13 @@ dfwdDN = \case EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index ffc2929..db66540 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,8 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) @@ -28,6 +30,7 @@ import Data.Functor.Identity import qualified Data.Functor.Product as Product import Data.Int (Int64) import Data.IORef +import Data.Tuple (swap) import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) @@ -143,6 +146,39 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do + let STArr _ (STPair t1 ttape) = typeOf ebog + let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef + parr <- interpret' env ep + zi <- interpret' env ezi + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape parr + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let (prefix, tape) = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res) + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -411,3 +447,11 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/Simplify.hs b/src/Simplify.hs index 51870d4..c1f92f1 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -314,6 +314,8 @@ simplify'Rec = \case EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm t2 a b c d e -> [simprec| EFold1InnerD2 ext cm t2 *a *b *c *d *e |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -367,6 +369,8 @@ hasAdds = \case EReplicate1Inner _ a b -> hasAdds a || hasAdds b EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e |
