From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/AST.hs | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) (limited to 'src/AST.hs') diff --git a/src/AST.hs b/src/AST.hs index 7549ff0..663b83f 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -71,21 +71,24 @@ data Expr x env t where EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - -- MapAccum-like (is it real mapaccum? If so, rename) + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 b) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) - -- Reverse derivative of Efold1Inner. - EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env t2 -- zero - -> Expr x (t2 : t2 : env) t2 -- plus - -> Expr x env (TArr (S n) b) -- extra data passed to function + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -237,7 +240,7 @@ typeOf = \case EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) - EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -287,7 +290,7 @@ extOf = \case EMinimum1Inner x _ -> x EReshape x _ _ _ -> x EFold1InnerD1 x _ _ _ _ -> x - EFold1InnerD2 x _ _ _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -336,7 +339,7 @@ travExt f = \case EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -398,7 +401,7 @@ subst' f w = \case EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -565,6 +568,19 @@ eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) eshapeConst ShNil = ENil ext eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + -- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) -- ezeroD2 t ezi = EZero ext (d2M t) ezi @@ -594,7 +610,9 @@ esnd e = ESnd ext e elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b elet rhs body | Dict <- styKnown (typeOf rhs) - = ELet ext rhs body + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body -- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) use :: Ex env a -> Ex env b -> Ex env b -- cgit v1.2.3-70-g09d2