aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-22 22:41:09 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-22 22:41:09 +0100
commit9b7c3eea7e34f5eb0d91f93b803e853028c2cec8 (patch)
tree25b906bb49218d2743631d0c83e23717012e3b9b /src/CHAD/AST.hs
parentb4f07c673b7c710f5861bb84e67233c63336c53d (diff)
WIP: Think about fusionfusion
Diffstat (limited to 'src/CHAD/AST.hs')
-rw-r--r--src/CHAD/AST.hs163
1 files changed, 90 insertions, 73 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index be7f95e..51ed747 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
@@ -38,64 +39,64 @@ import CHAD.Drev.Types
-- intended to be eliminated after simplification, so that the input program as
-- well as the output program do not contain these constructors.
-- TODO: ensure this by a "stage" type parameter.
-type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
-data Expr x env t where
+type Expr :: ((Ty -> Type) -> [Ty] -> Ty -> Type) -> (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr fext x env t where
-- lambda calculus
- EVar :: x t -> STy t -> Idx env t -> Expr x env t
- ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
+ EVar :: x t -> STy t -> Idx env t -> Expr f x env t
+ ELet :: x t -> Expr f x env a -> Expr f x (a : env) t -> Expr f x env t
-- base types
- EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
- EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
- ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
- ENil :: x TNil -> Expr x env TNil
- EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
- EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
- ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
- ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
- EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
- EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
+ EPair :: x (TPair a b) -> Expr f x env a -> Expr f x env b -> Expr f x env (TPair a b)
+ EFst :: x a -> Expr f x env (TPair a b) -> Expr f x env a
+ ESnd :: x b -> Expr f x env (TPair a b) -> Expr f x env b
+ ENil :: x TNil -> Expr f x env TNil
+ EInl :: x (TEither a b) -> STy b -> Expr f x env a -> Expr f x env (TEither a b)
+ EInr :: x (TEither a b) -> STy a -> Expr f x env b -> Expr f x env (TEither a b)
+ ECase :: x c -> Expr f x env (TEither a b) -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c
+ ENothing :: x (TMaybe t) -> STy t -> Expr f x env (TMaybe t)
+ EJust :: x (TMaybe t) -> Expr f x env t -> Expr f x env (TMaybe t)
+ EMaybe :: x b -> Expr f x env b -> Expr f x (t : env) b -> Expr f x env (TMaybe t) -> Expr f x env b
-- array operations
- EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr f x env (TArr n (TScal t))
+ EBuild :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x (Tup (Replicate n TIx) : env) t -> Expr f x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr f x (a : env) t -> Expr f x env (TArr n a) -> Expr f x env (TArr n t)
-- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
- EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
- ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
- EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
- EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr f x (TPair t t : env) t -> Expr f x env t -> Expr f x env (TArr (S n) t) -> Expr f x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EUnit :: x (TArr Z t) -> Expr f x env t -> Expr f x env (TArr Z t)
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) -> Expr f x env (TArr (S n) t)
+ EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env (TArr m t) -> Expr f x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr f x env (TArr n a) -> Expr f x env (TArr n b) -> Expr f x env (TArr n (TPair a b))
-- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
-- an implementation is allowed to parallelise this thing and store the b
-- values in some implementation-defined order.
-- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
- -> Expr x (TPair t1 t1 : env) (TPair t1 b)
- -> Expr x env t1
- -> Expr x env (TArr (S n) t1)
- -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ -> Expr f x (TPair t1 t1 : env) (TPair t1 b)
+ -> Expr f x env t1
+ -> Expr f x env (TArr (S n) t1)
+ -> Expr f x env (TPair (TArr n t1) -- normal primal fold output
(TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
-- Reverse derivative of EFold1Inner. The contributions to the initial
-- element are not yet added together here; we assume a later fusion system
-- does that for us.
EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
- -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
- -> Expr x env (TArr n t2) -- incoming cotangent
- -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
+ -> Expr f x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr f x env (TArr (S n) b) -- stores from EFold1InnerD1
+ -> Expr f x env (TArr n t2) -- incoming cotangent
+ -> Expr f x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-- expression operations
- EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
- EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
- EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
- EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
- EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+ EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr f x env (TScal t)
+ EIdx0 :: x t -> Expr f x env (TArr Z t) -> Expr f x env t
+ EIdx1 :: x (TArr n t) -> Expr f x env (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t)
+ EIdx :: x t -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env t
+ EShape :: x (Tup (Replicate n TIx)) -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx))
+ EOp :: x t -> SOp a t -> Expr f x env a -> Expr f x env t
-- custom derivatives
-- 'b' is the part of the input of the operation that derivatives should
@@ -106,43 +107,49 @@ data Expr x env t where
-- currently not used very much, so could be relaxed in the future; be sure
-- to check this requirement whenever it is necessary for soundness!
ECustom :: x t -> STy a -> STy b -> STy tape
- -> Expr x [b, a] t -- ^ regular operation
- -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
- -> Expr x env a -> Expr x env b
- -> Expr x env t
+ -> Expr f x [b, a] t -- ^ regular operation
+ -> Expr f x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr f x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr f x env a -> Expr f x env b
+ -> Expr f x env t
-- fake halfway checkpointing
- ERecompute :: x t -> Expr x env t -> Expr x env t
+ ERecompute :: x t -> Expr f x env t -> Expr f x env t
-- accumulation effect on monoids
-- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
-- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
-- need to create any zeros.
- EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EWith :: x (TPair a t) -> SMTy t -> Expr f x env t -> Expr f x (TAccum t : env) a -> Expr f x env (TPair a t)
-- The 'Sparse' here is eliminated to dense by UnMonoid.
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxD p t) -> Sparse a b -> Expr f x env b -> Expr f x env (TAccum t) -> Expr f x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
- EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
- EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
- EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+ EZero :: x t -> SMTy t -> Expr f x env (ZeroInfo t) -> Expr f x env t
+ EDeepZero :: x t -> SMTy t -> Expr f x env (DeepZeroInfo t) -> Expr f x env t
+ EPlus :: x t -> SMTy t -> Expr f x env t -> Expr f x env t -> Expr f x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxS p t) -> Expr f x env a -> Expr f x env t
-- interface of abstract monoidal types
- ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
- ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
- ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
- ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr f x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr f x env a -> Expr f x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr f x env b -> Expr f x env (TLEither a b)
+ ELCase :: x c -> Expr f x env (TLEither a b) -> Expr f x env c -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c
-- partiality
- EError :: x a -> STy a -> String -> Expr x env a
-deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+ EError :: x a -> STy a -> String -> Expr f x env a
+
+ -- extension point
+ EExt :: x a -> STy a -> !(f x env a) -> Expr f x env a
+deriving instance (forall ty. Show (x ty), forall env' ty. Show (f x env' ty)) => Show (Expr f x env t)
+
+data NoExt x env a
+ deriving (Show)
-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@);
-- 'Ex' sets this to nothing.
-type Ex = Expr (Const ())
+type Ex = Expr NoExt (Const ())
ext :: Const () a
ext = Const ()
@@ -211,7 +218,7 @@ opt2 = \case
OIDiv t -> STScal t
OMod t -> STScal t
-typeOf :: Expr x env t -> STy t
+typeOf :: Expr f x env t -> STy t
typeOf = \case
EVar _ t _ -> t
ELet _ _ e -> typeOf e
@@ -266,7 +273,9 @@ typeOf = \case
EError _ t _ -> t
-extOf :: Expr x env t -> x t
+ EExt _ t _ -> t
+
+extOf :: Expr f x env t -> x t
extOf = \case
EVar x _ _ -> x
ELet x _ _ -> x
@@ -312,12 +321,13 @@ extOf = \case
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+ EExt x _ _ -> x
-mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt :: TravExtEExt f => (forall a. x a -> x' a) -> Expr f x env t -> Expr f x' env t
mapExt f = runIdentity . travExt (Identity . f)
-{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
-travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt :: (Applicative f, TravExtEExt fe)
+ => (forall a. x a -> f (x' a)) -> Expr fe x env t -> f (Expr fe x' env t)
travExt f = \case
EVar x t i -> EVar <$> f x <*> pure t <*> pure i
ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
@@ -363,8 +373,15 @@ travExt f = \case
EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
EError x t s -> EError <$> f x <*> pure t <*> pure s
+ EExt x t v -> EExt <$> f x <*> pure t <*> travExtEExt f v
+
+class TravExtEExt fe where
+ travExtEExt :: Applicative f => (forall a. x a -> f (x' a)) -> fe x env t -> f (fe x' env t)
+
+instance TravExtEExt NoExt where
+ travExtEExt _ v = case v of {}
-substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+substInline :: Expr NoExt x env a -> Expr NoExt x (a : env) t -> Expr NoExt x env t
substInline repl =
subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i
@@ -374,14 +391,14 @@ subst0 repl =
subst $ \_ t -> \case IZ -> repl
IS i -> EVar ext t (IS i)
-subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
- -> Expr x env t -> Expr x env' t
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr NoExt x env' a)
+ -> Expr NoExt x env t -> Expr NoExt x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
-subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr NoExt x env2 a)
-> env' :> envOut
- -> Expr x env t
- -> Expr x envOut t
+ -> Expr NoExt x env t
+ -> Expr NoExt x envOut t
subst' f w = \case
EVar x t i -> f x t w i
ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
@@ -428,13 +445,13 @@ subst' f w = \case
EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
where
- sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr f x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr f x env2 t
sinkF f' x' t w' = \case
IZ -> EVar x' t (w' @> IZ)
IS i -> f' x' t (WPop w') i
-weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr :: env :> env' -> Expr NoExt x env t -> Expr NoExt x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
class KnownScalTy t where knownScalTy :: SScalTy t
@@ -495,7 +512,7 @@ envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-cheapExpr :: Expr x env t -> Bool
+cheapExpr :: Expr f x env t -> Bool
cheapExpr = \case
EVar{} -> True
ENil{} -> True