aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/CHAD/AST.hs163
-rw-r--r--src/CHAD/AST/Bindings.hs6
-rw-r--r--src/CHAD/AST/Count.hs12
-rw-r--r--src/CHAD/AST/Pretty.hs16
-rw-r--r--src/CHAD/Analysis/Identity.hs6
-rw-r--r--src/CHAD/Drev.hs18
-rw-r--r--src/CHAD/Example/GMM.hs2
-rw-r--r--src/CHAD/Fusion.hs115
-rw-r--r--src/CHAD/Simplify.hs2
10 files changed, 238 insertions, 104 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1eef3ed..834f1d7 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -44,6 +44,7 @@ library
CHAD.ForwardAD
CHAD.ForwardAD.DualNumbers
CHAD.ForwardAD.DualNumbers.Types
+ CHAD.Fusion
CHAD.Interpreter
-- CHAD.Interpreter.AccumOld
CHAD.Interpreter.Rep
@@ -58,6 +59,7 @@ library
base >= 4.19 && < 4.21,
containers,
deepseq,
+ dependent-map,
directory,
prettyprinter,
process,
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index be7f95e..51ed747 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
@@ -38,64 +39,64 @@ import CHAD.Drev.Types
-- intended to be eliminated after simplification, so that the input program as
-- well as the output program do not contain these constructors.
-- TODO: ensure this by a "stage" type parameter.
-type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
-data Expr x env t where
+type Expr :: ((Ty -> Type) -> [Ty] -> Ty -> Type) -> (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr fext x env t where
-- lambda calculus
- EVar :: x t -> STy t -> Idx env t -> Expr x env t
- ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
+ EVar :: x t -> STy t -> Idx env t -> Expr f x env t
+ ELet :: x t -> Expr f x env a -> Expr f x (a : env) t -> Expr f x env t
-- base types
- EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
- EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
- ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
- ENil :: x TNil -> Expr x env TNil
- EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
- EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
- ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
- ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
- EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
- EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
+ EPair :: x (TPair a b) -> Expr f x env a -> Expr f x env b -> Expr f x env (TPair a b)
+ EFst :: x a -> Expr f x env (TPair a b) -> Expr f x env a
+ ESnd :: x b -> Expr f x env (TPair a b) -> Expr f x env b
+ ENil :: x TNil -> Expr f x env TNil
+ EInl :: x (TEither a b) -> STy b -> Expr f x env a -> Expr f x env (TEither a b)
+ EInr :: x (TEither a b) -> STy a -> Expr f x env b -> Expr f x env (TEither a b)
+ ECase :: x c -> Expr f x env (TEither a b) -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c
+ ENothing :: x (TMaybe t) -> STy t -> Expr f x env (TMaybe t)
+ EJust :: x (TMaybe t) -> Expr f x env t -> Expr f x env (TMaybe t)
+ EMaybe :: x b -> Expr f x env b -> Expr f x (t : env) b -> Expr f x env (TMaybe t) -> Expr f x env b
-- array operations
- EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr f x env (TArr n (TScal t))
+ EBuild :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x (Tup (Replicate n TIx) : env) t -> Expr f x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr f x (a : env) t -> Expr f x env (TArr n a) -> Expr f x env (TArr n t)
-- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
- EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
- ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
- EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
- EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr f x (TPair t t : env) t -> Expr f x env t -> Expr f x env (TArr (S n) t) -> Expr f x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EUnit :: x (TArr Z t) -> Expr f x env t -> Expr f x env (TArr Z t)
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) -> Expr f x env (TArr (S n) t)
+ EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env (TArr m t) -> Expr f x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr f x env (TArr n a) -> Expr f x env (TArr n b) -> Expr f x env (TArr n (TPair a b))
-- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
-- an implementation is allowed to parallelise this thing and store the b
-- values in some implementation-defined order.
-- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
- -> Expr x (TPair t1 t1 : env) (TPair t1 b)
- -> Expr x env t1
- -> Expr x env (TArr (S n) t1)
- -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ -> Expr f x (TPair t1 t1 : env) (TPair t1 b)
+ -> Expr f x env t1
+ -> Expr f x env (TArr (S n) t1)
+ -> Expr f x env (TPair (TArr n t1) -- normal primal fold output
(TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
-- Reverse derivative of EFold1Inner. The contributions to the initial
-- element are not yet added together here; we assume a later fusion system
-- does that for us.
EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
- -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
- -> Expr x env (TArr n t2) -- incoming cotangent
- -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
+ -> Expr f x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr f x env (TArr (S n) b) -- stores from EFold1InnerD1
+ -> Expr f x env (TArr n t2) -- incoming cotangent
+ -> Expr f x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-- expression operations
- EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
- EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
- EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
- EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
- EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+ EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr f x env (TScal t)
+ EIdx0 :: x t -> Expr f x env (TArr Z t) -> Expr f x env t
+ EIdx1 :: x (TArr n t) -> Expr f x env (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t)
+ EIdx :: x t -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env t
+ EShape :: x (Tup (Replicate n TIx)) -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx))
+ EOp :: x t -> SOp a t -> Expr f x env a -> Expr f x env t
-- custom derivatives
-- 'b' is the part of the input of the operation that derivatives should
@@ -106,43 +107,49 @@ data Expr x env t where
-- currently not used very much, so could be relaxed in the future; be sure
-- to check this requirement whenever it is necessary for soundness!
ECustom :: x t -> STy a -> STy b -> STy tape
- -> Expr x [b, a] t -- ^ regular operation
- -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
- -> Expr x env a -> Expr x env b
- -> Expr x env t
+ -> Expr f x [b, a] t -- ^ regular operation
+ -> Expr f x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr f x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr f x env a -> Expr f x env b
+ -> Expr f x env t
-- fake halfway checkpointing
- ERecompute :: x t -> Expr x env t -> Expr x env t
+ ERecompute :: x t -> Expr f x env t -> Expr f x env t
-- accumulation effect on monoids
-- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
-- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
-- need to create any zeros.
- EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EWith :: x (TPair a t) -> SMTy t -> Expr f x env t -> Expr f x (TAccum t : env) a -> Expr f x env (TPair a t)
-- The 'Sparse' here is eliminated to dense by UnMonoid.
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxD p t) -> Sparse a b -> Expr f x env b -> Expr f x env (TAccum t) -> Expr f x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
- EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
- EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
- EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+ EZero :: x t -> SMTy t -> Expr f x env (ZeroInfo t) -> Expr f x env t
+ EDeepZero :: x t -> SMTy t -> Expr f x env (DeepZeroInfo t) -> Expr f x env t
+ EPlus :: x t -> SMTy t -> Expr f x env t -> Expr f x env t -> Expr f x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxS p t) -> Expr f x env a -> Expr f x env t
-- interface of abstract monoidal types
- ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
- ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
- ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
- ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr f x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr f x env a -> Expr f x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr f x env b -> Expr f x env (TLEither a b)
+ ELCase :: x c -> Expr f x env (TLEither a b) -> Expr f x env c -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c
-- partiality
- EError :: x a -> STy a -> String -> Expr x env a
-deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+ EError :: x a -> STy a -> String -> Expr f x env a
+
+ -- extension point
+ EExt :: x a -> STy a -> !(f x env a) -> Expr f x env a
+deriving instance (forall ty. Show (x ty), forall env' ty. Show (f x env' ty)) => Show (Expr f x env t)
+
+data NoExt x env a
+ deriving (Show)
-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@);
-- 'Ex' sets this to nothing.
-type Ex = Expr (Const ())
+type Ex = Expr NoExt (Const ())
ext :: Const () a
ext = Const ()
@@ -211,7 +218,7 @@ opt2 = \case
OIDiv t -> STScal t
OMod t -> STScal t
-typeOf :: Expr x env t -> STy t
+typeOf :: Expr f x env t -> STy t
typeOf = \case
EVar _ t _ -> t
ELet _ _ e -> typeOf e
@@ -266,7 +273,9 @@ typeOf = \case
EError _ t _ -> t
-extOf :: Expr x env t -> x t
+ EExt _ t _ -> t
+
+extOf :: Expr f x env t -> x t
extOf = \case
EVar x _ _ -> x
ELet x _ _ -> x
@@ -312,12 +321,13 @@ extOf = \case
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+ EExt x _ _ -> x
-mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt :: TravExtEExt f => (forall a. x a -> x' a) -> Expr f x env t -> Expr f x' env t
mapExt f = runIdentity . travExt (Identity . f)
-{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
-travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt :: (Applicative f, TravExtEExt fe)
+ => (forall a. x a -> f (x' a)) -> Expr fe x env t -> f (Expr fe x' env t)
travExt f = \case
EVar x t i -> EVar <$> f x <*> pure t <*> pure i
ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
@@ -363,8 +373,15 @@ travExt f = \case
EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
EError x t s -> EError <$> f x <*> pure t <*> pure s
+ EExt x t v -> EExt <$> f x <*> pure t <*> travExtEExt f v
+
+class TravExtEExt fe where
+ travExtEExt :: Applicative f => (forall a. x a -> f (x' a)) -> fe x env t -> f (fe x' env t)
+
+instance TravExtEExt NoExt where
+ travExtEExt _ v = case v of {}
-substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+substInline :: Expr NoExt x env a -> Expr NoExt x (a : env) t -> Expr NoExt x env t
substInline repl =
subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i
@@ -374,14 +391,14 @@ subst0 repl =
subst $ \_ t -> \case IZ -> repl
IS i -> EVar ext t (IS i)
-subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
- -> Expr x env t -> Expr x env' t
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr NoExt x env' a)
+ -> Expr NoExt x env t -> Expr NoExt x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
-subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr NoExt x env2 a)
-> env' :> envOut
- -> Expr x env t
- -> Expr x envOut t
+ -> Expr NoExt x env t
+ -> Expr NoExt x envOut t
subst' f w = \case
EVar x t i -> f x t w i
ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
@@ -428,13 +445,13 @@ subst' f w = \case
EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
where
- sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr f x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr f x env2 t
sinkF f' x' t w' = \case
IZ -> EVar x' t (w' @> IZ)
IS i -> f' x' t (WPop w') i
-weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr :: env :> env' -> Expr NoExt x env t -> Expr NoExt x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
class KnownScalTy t where knownScalTy :: SScalTy t
@@ -495,7 +512,7 @@ envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-cheapExpr :: Expr x env t -> Bool
+cheapExpr :: Expr f x env t -> Bool
cheapExpr = \case
EVar{} -> True
ENil{} -> True
diff --git a/src/CHAD/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs
index c1a1e77..3ecda3e 100644
--- a/src/CHAD/AST/Bindings.hs
+++ b/src/CHAD/AST/Bindings.hs
@@ -28,7 +28,7 @@ data Bindings f env binds where
deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
infixl `BPush`
-bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds)
+bpush :: Bindings (Expr NoExt x) env binds -> Expr NoExt x (Append binds env) t -> Bindings (Expr NoExt x) env (t : binds)
bpush b e = b `BPush` (typeOf e, e)
infixl `bpush`
@@ -47,8 +47,8 @@ weakenBindings wf w (BPush b (t, x)) =
in (BPush b' (t, wf w' x), WCopy w')
weakenBindingsE :: env1 :> env2
- -> Bindings (Expr x) env1 binds
- -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2)
+ -> Bindings (Expr NoExt x) env1 binds
+ -> (Bindings (Expr NoExt x) env2 binds, Append binds env1 :> Append binds env2)
weakenBindingsE = weakenBindings weakenExpr
weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs
index 46173d2..1dad758 100644
--- a/src/CHAD/AST/Count.hs
+++ b/src/CHAD/AST/Count.hs
@@ -338,15 +338,15 @@ envMaskPrj (EMRest b) _ = b
envMaskPrj (_ `EMPush` b) IZ = b
envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
-occCount :: Idx env a -> Expr x env t -> Occ
+occCount :: Idx env a -> Expr NoExt x env t -> Occ
occCount idx ex
| Some env <- occCountAll ex
= fst (occEnvPrj env idx)
-occCountAll :: Expr x env t -> Some (OccEnv Occ env)
+occCountAll :: Expr NoExt x env t -> Some (OccEnv Occ env)
occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
-pruneExpr :: SList f env -> Expr x env t -> Ex env t
+pruneExpr :: SList f env -> Expr NoExt x env t -> Ex env t
pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
where
fullOccEnv :: SList f env -> OccEnv () env env
@@ -365,7 +365,7 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
-- occurrence counts. The callback reconstructs a new expression in an
-- updated "response" environment. The response must be at least as large as
-- the computed usages.
-occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
+occCountX :: forall env t t' x r. Substruc t t' -> Expr NoExt x env t
-> (forall env'. OccEnv Occ env env'
-- response OccEnv must be at least as large as the OccEnv returned above
-> (forall env''. OccEnv () env env'' -> Ex env'' t')
@@ -885,7 +885,7 @@ occCountX initialS topexpr k = case topexpr of
handleReduction :: t ~ TArr n (TScal t2)
=> (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
- -> Expr x env (TArr (S n) (TScal t2))
+ -> Expr NoExt x env (TArr (S n) (TScal t2))
-> r
handleReduction reduce e
| STArr (SS n) _ <- typeOf e =
@@ -914,7 +914,7 @@ deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
case count of Zero -> k (SENo sub)
_ -> k (SEYesR sub)
-unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
+unsafeWeakenWithSubenv :: Subenv env env' -> Expr NoExt x env t -> Expr NoExt x env' t
unsafeWeakenWithSubenv = \sub ->
subst (\x t i -> case sinkViaSubenv i sub of
Just i' -> EVar x t i'
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index 9ddcb35..b763efe 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -63,20 +63,20 @@ nameBaseForType _ = "x"
genName' :: String -> M String
genName' prefix = (prefix ++) . show <$> genId
-genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr NoExt x env t -> M String
genNameIfUsedIn' prefix ty idx ex
| occCount idx ex == mempty = case ty of STNil -> return "()"
_ -> return "_"
| otherwise = genName' prefix
-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
-genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn :: STy a -> Idx env a -> Expr NoExt x env t -> M String
genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
-pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
+pprintExpr :: (KnownEnv env, PrettyX x) => Expr NoExt x env t -> IO ()
pprintExpr = putStrLn . ppExpr knownEnv
-ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr :: PrettyX x => SList STy env -> Expr NoExt x env t -> String
ppExpr senv e = render $ fst . flip runM 1 $ do
val <- mkVal senv
e' <- ppExpr' 0 val e
@@ -94,7 +94,7 @@ ppExpr senv e = render $ fst . flip runM 1 $ do
name <- genName' "arg"
return (Const name `SCons` val)
-ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExpr' :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc
ppExpr' d val expr = case expr of
EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr
@@ -374,9 +374,9 @@ ppExpr' d val expr = case expr of
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
-ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExprLet :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc
ppExprLet d val etop = do
- let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
+ let collect :: PrettyX x => SVal env -> Expr NoExt x env t -> M ([(String, Occ, ADoc)], ADoc)
collect val' (ELet _ rhs body) = do
let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
@@ -426,7 +426,7 @@ ppCommut :: Commutative -> String
ppCommut Commut = "(C)"
ppCommut Noncommut = ""
-ppX :: PrettyX x => Expr x env t -> ADoc
+ppX :: PrettyX x => Expr NoExt x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
data Fixity = Prefix | Infix
diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs
index 212cc7d..b637f88 100644
--- a/src/CHAD/Analysis/Identity.hs
+++ b/src/CHAD/Analysis/Identity.hs
@@ -63,15 +63,15 @@ validSplitEither (VIEither (Right v)) = (Nothing, Just v)
validSplitEither (VIEither' v1 v2) = (Just v1, Just v2)
-- | Symbolic partial evaluation.
-identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
+identityAnalysis :: SList STy env -> Expr NoExt x env t -> Expr NoExt ValId env t
identityAnalysis env term = runIdGen 0 $ do
env' <- slistMapA genIds env
snd <$> idana env' term
-identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' :: SList ValId env -> Expr NoExt x env t -> Expr NoExt ValId env t
identityAnalysis' env term = snd (runIdGen 0 (idana env term))
-idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
+idana :: SList ValId env -> Expr NoExt x env t -> IdGen (ValId t, Expr NoExt ValId env t)
idana env expr = case expr of
EVar _ t i -> do
let v = slistIdx env i
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index bfa964b..eba3719 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -726,7 +726,7 @@ drev :: forall env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> Sparse (D2 t) sd
- -> Expr ValId env t -> Ret env sto sd t
+ -> Expr NoExt ValId env t -> Ret env sto sd t
drev des _ sd | isAbsent sd =
\e ->
Ret BTop
@@ -774,7 +774,7 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- ELet _ (rhs :: Expr _ _ a) body
+ ELet _ (rhs :: Expr _ _ _ a) body
| ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
, RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
, Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
@@ -872,7 +872,7 @@ drev des accumMap sd = \case
(EError ext (contribTupTy des sub') "inr<-dinl")
(inj1 $ weakenExpr (WCopy WSink) e2))
- ECase _ e (a :: Expr _ _ t) b
+ ECase _ e (a :: Expr _ _ _ t) b
| STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
@@ -1041,7 +1041,7 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ _ eltty)
| SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf ef
, shty :: STy shty <- tTup (sreplicate ndim tIx)
@@ -1081,7 +1081,7 @@ drev des accumMap sd = \case
(#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap _ ef (earr :: Expr _ _ (TArr n a))
+ EMap _ ef (earr :: Expr _ _ _ (TArr n a))
| SpArr sdElt <- sd
, let STArr ndim t1 = typeOf earr
t2 = typeOf ef ->
@@ -1391,7 +1391,7 @@ deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
-> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> Sparse (D2s t) sd
- -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+ -> Expr NoExt ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
deriv_extremum extremum des accumMap sd e
| at@(STArr (SS n) t@(STScal st)) <- typeOf e
, let at' = STArr n t
@@ -1437,7 +1437,7 @@ drevScoped :: forall a s env sto sd t.
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> STy a -> Storage s -> Maybe (ValId a)
-> Sparse (D2 t) sd
- -> Expr ValId (a : env) t
+ -> Expr NoExt ValId (a : env) t
-> RetScoped env sto a s sd t
drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
@@ -1496,7 +1496,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
-> VarMap Int (D2AcE (Select env sto "accum"))
-> (STy a, Storage s)
-> Sparse (D2 t) dt
- -> Expr ValId (a : env) t
+ -> Expr NoExt ValId (a : env) t
-> (forall provars shbinds tape d2a'.
SList STy provars
-> Subenv (D2E (Select env sto "merge")) (D2E provars)
@@ -1574,7 +1574,7 @@ drevLambda des accumMap (argty, argsto) sd origef k =
prf1 _ _ SDiscr = Refl
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
-drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal :: Descr env sto -> Expr NoExt x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
| Refl <- d1Identity (typeOf e)
, Refl <- d1eIdentity (descrList des)
diff --git a/src/CHAD/Example/GMM.hs b/src/CHAD/Example/GMM.hs
index 18641e8..2b2ac2b 100644
--- a/src/CHAD/Example/GMM.hs
+++ b/src/CHAD/Example/GMM.hs
@@ -112,7 +112,7 @@ gmmObjective wrong = fromNamed $
qmat q l = inline qmat' (SNil .$ q .$ l)
in let_ #k2arr (unit #k2) $
- #k1
- + idx0 (sum1i (build1 #N $ #i :->
+ + idx0 (sum1i (build1 #N $ #i :-> recompute $
logsumexp (build1 #K $ #k :->
#alpha ! pair nil #k
+ idx0 (sum1i (#Q .! #k))
diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs
new file mode 100644
index 0000000..757667f
--- /dev/null
+++ b/src/CHAD/Fusion.hs
@@ -0,0 +1,115 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Fusion where
+
+import Data.Dependent.Map (DMap)
+-- import Data.Dependent.Map qualified as DMap
+import Data.Functor.Const
+import Data.Kind (Type)
+import Numeric.Natural
+
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.Data
+
+
+-- TODO:
+-- A bunch of data types are defined here that should be able to express a
+-- graph of loop nests. A graph is a straight-line program whose statements
+-- are, in this case, loop nests. A loop nest corresponds to what fusion
+-- normally calls a "cluster", but is here represented as, well, a loop nest.
+--
+-- No unzipping is done here, as I don't think it is necessary: I haven't been
+-- able to think of programs that get more fusion opportunities when unzipped
+-- than when zipped. If any such programs exist, I in any case conjecture that
+-- with a pre-pass that splits array operations that can be unzipped already at
+-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
+-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
+-- this conjecture is false, some reorganisation may be required.
+--
+-- Next steps, perhaps:
+-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
+-- specifically but its fields. It will have a single output, and its args
+-- will be its list of free variables.
+-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
+-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
+-- a large expression that contain only "simple" AST constructors, and
+-- replaces those subterms with an EExt containing that graph. In this
+-- construction process, EBuild and ESum1Inner should be replaced with
+-- FLoop.
+-- 4. Implement fusion somehow on graphs!
+-- 5. Add an AST constructor for a loop nest (which most of the modules throw
+-- an error on, except Count, Simplify and Compile -- good luck with Count),
+-- and compile that to an actual C loop nest.
+-- 6. Extend to other cool operations like EFold1InnerD1
+
+
+type FEx = Expr FGraph (Const ())
+
+type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data FGraph x env t where
+ FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t
+
+data Node env t where
+ FFreeVar :: STy t -> Idx env t -> Node env t
+ FLoop :: SList NodeId args
+ -> SList STy outs
+ -> LoopNest args outs
+ -> Tuple (Idx outs) t
+ -> Node env t
+
+data NodeId t = NodeId Natural (STy t)
+ deriving (Show)
+
+data Tuple f t where
+ TupNil :: Tuple f TNil
+ TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
+ TupSingle :: f t -> Tuple f t
+deriving instance (forall a. Show (f a)) => Show (Tuple f t)
+
+data LoopNest args outs where
+ Inner :: Bindings Ex args bs
+ -> SList (Idx (Append bs args)) outs
+ -> LoopNest args outs
+ -- this should be able to express a simple nesting of builds and sums.
+ Layer :: Bindings Ex args bs1
+ -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
+ -> LoopNest (TIx : Append bs1 args) loopouts
+ -> Partition BuildUp RedSum loopouts mapouts sumouts
+ -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
+ -> SList (Idx (Append bs2 args)) outs
+ -> LoopNest args (Append outs mapouts)
+
+type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
+data Partition f1 f2 ts ts1 ts2 where
+ PNil :: Partition f1 f2 '[] '[] '[]
+ Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
+ Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)
+
+data BuildUp t t' where
+ BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)
+
+data RedSum t t' where
+ RedSum :: SMTy t -> RedSum t t
+
+-- type family Unzip t where
+-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
+-- Unzip (TArr n t) = UnzipA n t
+
+-- type family UnzipA n t where
+-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
+-- UnzipA n t = TArr n t
+
+-- data Zipping ut t where
+-- ZId :: Zipping t t
+-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
+-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
+-- deriving instance Show (Zipping ut t)
+
+
diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs
index ea253d6..a09effc 100644
--- a/src/CHAD/Simplify.hs
+++ b/src/CHAD/Simplify.hs
@@ -364,7 +364,7 @@ simplify'Rec = \case
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
-hasAdds :: Expr x env t -> Bool
+hasAdds :: Expr NoExt x env t -> Bool
hasAdds = \case
EVar _ _ _ -> False
ELet _ rhs body -> hasAdds rhs || hasAdds body