diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-22 22:41:09 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-22 22:41:09 +0100 |
| commit | 9b7c3eea7e34f5eb0d91f93b803e853028c2cec8 (patch) | |
| tree | 25b906bb49218d2743631d0c83e23717012e3b9b /src/CHAD/AST.hs | |
| parent | b4f07c673b7c710f5861bb84e67233c63336c53d (diff) | |
WIP: Think about fusionfusion
Diffstat (limited to 'src/CHAD/AST.hs')
| -rw-r--r-- | src/CHAD/AST.hs | 163 |
1 files changed, 90 insertions, 73 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index be7f95e..51ed747 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE EmptyDataDeriving #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -38,64 +39,64 @@ import CHAD.Drev.Types -- intended to be eliminated after simplification, so that the input program as -- well as the output program do not contain these constructors. -- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where +type Expr :: ((Ty -> Type) -> [Ty] -> Ty -> Type) -> (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr fext x env t where -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + EVar :: x t -> STy t -> Idx env t -> Expr f x env t + ELet :: x t -> Expr f x env a -> Expr f x (a : env) t -> Expr f x env t -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + EPair :: x (TPair a b) -> Expr f x env a -> Expr f x env b -> Expr f x env (TPair a b) + EFst :: x a -> Expr f x env (TPair a b) -> Expr f x env a + ESnd :: x b -> Expr f x env (TPair a b) -> Expr f x env b + ENil :: x TNil -> Expr f x env TNil + EInl :: x (TEither a b) -> STy b -> Expr f x env a -> Expr f x env (TEither a b) + EInr :: x (TEither a b) -> STy a -> Expr f x env b -> Expr f x env (TEither a b) + ECase :: x c -> Expr f x env (TEither a b) -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c + ENothing :: x (TMaybe t) -> STy t -> Expr f x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr f x env t -> Expr f x env (TMaybe t) + EMaybe :: x b -> Expr f x env b -> Expr f x (t : env) b -> Expr f x env (TMaybe t) -> Expr f x env b -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr f x env (TArr n (TScal t)) + EBuild :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x (Tup (Replicate n TIx) : env) t -> Expr f x env (TArr n t) + EMap :: x (TArr n t) -> Expr f x (a : env) t -> Expr f x env (TArr n a) -> Expr f x env (TArr n t) -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr f x (TPair t t : env) t -> Expr f x env t -> Expr f x env (TArr (S n) t) -> Expr f x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EUnit :: x (TArr Z t) -> Expr f x env t -> Expr f x env (TArr Z t) + EReplicate1Inner :: x (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) -> Expr f x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env (TArr m t) -> Expr f x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr f x env (TArr n a) -> Expr f x env (TArr n b) -> Expr f x env (TArr n (TPair a b)) -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: -- an implementation is allowed to parallelise this thing and store the b -- values in some implementation-defined order. -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (TPair t1 t1 : env) (TPair t1 b) - -> Expr x env t1 - -> Expr x env (TArr (S n) t1) - -> Expr x env (TPair (TArr n t1) -- normal primal fold output + -> Expr f x (TPair t1 t1 : env) (TPair t1 b) + -> Expr f x env t1 + -> Expr f x env (TArr (S n) t1) + -> Expr f x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) -- Reverse derivative of EFold1Inner. The contributions to the initial -- element are not yet added together here; we assume a later fusion system -- does that for us. EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative - -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 - -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array + -> Expr f x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr f x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr f x env (TArr n t2) -- incoming cotangent + -> Expr f x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr f x env (TScal t) + EIdx0 :: x t -> Expr f x env (TArr Z t) -> Expr f x env t + EIdx1 :: x (TArr n t) -> Expr f x env (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) + EIdx :: x t -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) + EOp :: x t -> SOp a t -> Expr f x env a -> Expr f x env t -- custom derivatives -- 'b' is the part of the input of the operation that derivatives should @@ -106,43 +107,49 @@ data Expr x env t where -- currently not used very much, so could be relaxed in the future; be sure -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t + -> Expr f x [b, a] t -- ^ regular operation + -> Expr f x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr f x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr f x env a -> Expr f x env b + -> Expr f x env t -- fake halfway checkpointing - ERecompute :: x t -> Expr x env t -> Expr x env t + ERecompute :: x t -> Expr f x env t -> Expr f x env t -- accumulation effect on monoids -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not -- need to create any zeros. - EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EWith :: x (TPair a t) -> SMTy t -> Expr f x env t -> Expr f x (TAccum t : env) a -> Expr f x env (TPair a t) -- The 'Sparse' here is eliminated to dense by UnMonoid. - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxD p t) -> Sparse a b -> Expr f x env b -> Expr f x env (TAccum t) -> Expr f x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t - EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t - EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + EZero :: x t -> SMTy t -> Expr f x env (ZeroInfo t) -> Expr f x env t + EDeepZero :: x t -> SMTy t -> Expr f x env (DeepZeroInfo t) -> Expr f x env t + EPlus :: x t -> SMTy t -> Expr f x env t -> Expr f x env t -> Expr f x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxS p t) -> Expr f x env a -> Expr f x env t -- interface of abstract monoidal types - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr f x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr f x env a -> Expr f x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr f x env b -> Expr f x env (TLEither a b) + ELCase :: x c -> Expr f x env (TLEither a b) -> Expr f x env c -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + EError :: x a -> STy a -> String -> Expr f x env a + + -- extension point + EExt :: x a -> STy a -> !(f x env a) -> Expr f x env a +deriving instance (forall ty. Show (x ty), forall env' ty. Show (f x env' ty)) => Show (Expr f x env t) + +data NoExt x env a + deriving (Show) -- | A (well-typed, well-scoped) expression using De Bruijn indices. The full -- 'Expr' type is parametrised on an indexed type of "additional info" (@x@); -- 'Ex' sets this to nothing. -type Ex = Expr (Const ()) +type Ex = Expr NoExt (Const ()) ext :: Const () a ext = Const () @@ -211,7 +218,7 @@ opt2 = \case OIDiv t -> STScal t OMod t -> STScal t -typeOf :: Expr x env t -> STy t +typeOf :: Expr f x env t -> STy t typeOf = \case EVar _ t _ -> t ELet _ _ e -> typeOf e @@ -266,7 +273,9 @@ typeOf = \case EError _ t _ -> t -extOf :: Expr x env t -> x t + EExt _ t _ -> t + +extOf :: Expr f x env t -> x t extOf = \case EVar x _ _ -> x ELet x _ _ -> x @@ -312,12 +321,13 @@ extOf = \case EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x + EExt x _ _ -> x -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt :: TravExtEExt f => (forall a. x a -> x' a) -> Expr f x env t -> Expr f x' env t mapExt f = runIdentity . travExt (Identity . f) -{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} -travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt :: (Applicative f, TravExtEExt fe) + => (forall a. x a -> f (x' a)) -> Expr fe x env t -> f (Expr fe x' env t) travExt f = \case EVar x t i -> EVar <$> f x <*> pure t <*> pure i ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body @@ -363,8 +373,15 @@ travExt f = \case EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s + EExt x t v -> EExt <$> f x <*> pure t <*> travExtEExt f v + +class TravExtEExt fe where + travExtEExt :: Applicative f => (forall a. x a -> f (x' a)) -> fe x env t -> f (fe x' env t) + +instance TravExtEExt NoExt where + travExtEExt _ v = case v of {} -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline :: Expr NoExt x env a -> Expr NoExt x (a : env) t -> Expr NoExt x env t substInline repl = subst $ \x t -> \case IZ -> repl IS i -> EVar x t i @@ -374,14 +391,14 @@ subst0 repl = subst $ \_ t -> \case IZ -> repl IS i -> EVar ext t (IS i) -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t +subst :: (forall a. x a -> STy a -> Idx env a -> Expr NoExt x env' a) + -> Expr NoExt x env t -> Expr NoExt x env' t subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr NoExt x env2 a) -> env' :> envOut - -> Expr x env t - -> Expr x envOut t + -> Expr NoExt x env t + -> Expr NoExt x envOut t subst' f w = \case EVar x t i -> f x t w i ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) @@ -428,13 +445,13 @@ subst' f w = \case EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr f x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr f x env2 t sinkF f' x' t w' = \case IZ -> EVar x' t (w' @> IZ) IS i -> f' x' t (WPop w') i -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr :: env :> env' -> Expr NoExt x env t -> Expr NoExt x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) class KnownScalTy t where knownScalTy :: SScalTy t @@ -495,7 +512,7 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict -cheapExpr :: Expr x env t -> Bool +cheapExpr :: Expr f x env t -> Bool cheapExpr = \case EVar{} -> True ENil{} -> True |
