about | summary | refs | log | tree | commit | diff
path: root/src/AST.hs
diff options
context:
space:
mode:
author Tom Smeding <tom@tomsmeding.com> 2025-10-30 15:58:08 +0100
committer Tom Smeding <tom@tomsmeding.com> 2025-10-30 15:58:08 +0100
commit 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (patch)
tree e371c4962f1beee96cc68d55accffab16e18b97a /src/AST.hs
parent 4d456e4d34b1e4fb3725051d1b8a0c376b704692 (diff)
Simplify foldD2 to not sum x0 contributions
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- src/AST.hs 42
1 files changed, 30 insertions, 12 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 7549ff0..663b83f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -71,21 +71,24 @@ data Expr x env t where
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
- -- MapAccum-like (is it real mapaccum? If so, rename)
+ -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
+ -- an implementation is allowed to parallelise this thing and store the b
+ -- values in some implementation-defined order.
+ -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
-> Expr x (t1 : t1 : env) (TPair t1 b)
-> Expr x env t1
-> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
(TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
- -- Reverse derivative of Efold1Inner.
- EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative
+ -- Reverse derivative of EFold1Inner. The contributions to the initial
+ -- element are not yet added together here; we assume a later fusion system
+ -- does that for us.
+ EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
-> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env t2 -- zero
- -> Expr x (t2 : t2 : env) t2 -- plus
- -> Expr x env (TArr (S n) b) -- extra data passed to function
+ -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
-> Expr x env (TArr n t2) -- incoming cotangent
- -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
+ -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -237,7 +240,7 @@ typeOf = \case
EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
- EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
+ EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -287,7 +290,7 @@ extOf = \case
EMinimum1Inner x _ -> x
EReshape x _ _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
- EFold1InnerD2 x _ _ _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -336,7 +339,7 @@ travExt f = \case
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
@@ -398,7 +401,7 @@ subst' f w = \case
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
@@ -565,6 +568,19 @@ eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
eshapeConst ShNil = ENil ext
eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
+eshapeProd SZ _ = EConst ext STI64 1
+eshapeProd (SS SZ) e = ESnd ext e
+eshapeProd (SS n) e =
+ eunPair e $ \_ e1 e2 ->
+ EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
+
+eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
+eflatten e =
+ let STArr n _ = typeOf e
+ in elet e $
+ EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
+
-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
-- ezeroD2 t ezi = EZero ext (d2M t) ezi
@@ -594,7 +610,9 @@ esnd e = ESnd ext e
elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
elet rhs body
| Dict <- styKnown (typeOf rhs)
- = ELet ext rhs body
+ = if cheapExpr rhs
+ then substInline rhs body
+ else ELet ext rhs body
-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
use :: Ex env a -> Ex env b -> Ex env b