aboutsummaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-28 11:56:40 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-28 11:56:40 +0100
commit955af83f664639701fdbee54718186e07b31d42f (patch)
tree30353d77c69b1dfdaf43797942dbf6e412a49450 /src/AST.hs
parent765b80616583322226284266605ab3a916da01db (diff)
Better fold D{1,2} primitives
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs30
1 files changed, 16 insertions, 14 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 2d4fd91..f7b63cf 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -70,17 +70,19 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1)
+ -- MapAccum-like (is it real mapaccum? If so, rename)
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
+ -> Expr x (t1 : t1 : env) (TPair t1 b)
+ -> Expr x env t1
+ -> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
- (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores)
- -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. Perhaps a hack to reduce the impact is to store ZeroInfos only?
+ (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
+ -- Reverse derivative of Efold1Inner.
EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative
- -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0
- -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true
- -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env (TArr (S n) t1) -- primal input array
- -> Expr x env (ZeroInfo t2)
- -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1
+ -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env t2 -- zero
+ -> Expr x (t2 : t2 : env) t2 -- plus
+ -> Expr x env (TArr (S n) b) -- extra data passed to function
-> Expr x env (TArr n t2) -- incoming cotangent
-> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
@@ -232,8 +234,8 @@ typeOf = \case
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tape <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) (STPair t1 tape))
- EFold1InnerD2 _ _ t2 _ _ _ _ e5 | STArr n _ <- typeOf e5 -> STPair (fromSMTy t2) (STArr (SS n) (fromSMTy t2))
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
+ EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -282,7 +284,7 @@ extOf = \case
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
EFold1InnerD1 x _ _ _ _ -> x
- EFold1InnerD2 x _ _ _ _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -330,7 +332,7 @@ travExt f = \case
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> pure t2 <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
+ EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
@@ -391,7 +393,7 @@ subst' f w = \case
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- EFold1InnerD2 x cm t2 a b c d e -> EFold1InnerD2 x cm t2 (subst' (sinkF (sinkF (sinkF (sinkF f)))) (WCopy (WCopy (WCopy (WCopy w)))) a) (subst' f w b) (subst' f w c) (subst' f w d) (subst' f w e)
+ EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)