diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
| commit | 42176d4a8a0fe7954f17da5c0506721695aa477f (patch) | |
| tree | 8a29e847faa613e9becf1bccdcaad010187e639b /src/AST.hs | |
| parent | 7729c45a325fe653421d654ed4c28b040585fce9 (diff) | |
WIP fold: everything but Compile (slow, but should be sound)
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 16 |
1 files changed, 13 insertions, 3 deletions
@@ -70,15 +70,16 @@ data Expr x env t where EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores) - -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. - EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative + -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only? + EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0 -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) -> Expr x env (TArr (S n) t1) -- primal input array + -> Expr x env (ZeroInfo t2) -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1 -> Expr x env (TArr n t2) -- incoming cotangent -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array @@ -231,6 +232,9 @@ typeOf = \case EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape)) + EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2)) + EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t @@ -277,6 +281,8 @@ extOf = \case EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -323,6 +329,8 @@ travExt f = \case EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -382,6 +390,8 @@ subst' f w = \case EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) |
