diff options
Diffstat (limited to 'src/CHAD')
| -rw-r--r-- | src/CHAD/AST.hs | 163 | ||||
| -rw-r--r-- | src/CHAD/AST/Bindings.hs | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 12 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 16 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs | 6 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs | 18 | ||||
| -rw-r--r-- | src/CHAD/Example/GMM.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Fusion.hs | 115 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs | 2 |
9 files changed, 236 insertions, 104 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index be7f95e..51ed747 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE EmptyDataDeriving #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -38,64 +39,64 @@ import CHAD.Drev.Types -- intended to be eliminated after simplification, so that the input program as -- well as the output program do not contain these constructors. -- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where +type Expr :: ((Ty -> Type) -> [Ty] -> Ty -> Type) -> (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr fext x env t where -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + EVar :: x t -> STy t -> Idx env t -> Expr f x env t + ELet :: x t -> Expr f x env a -> Expr f x (a : env) t -> Expr f x env t -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + EPair :: x (TPair a b) -> Expr f x env a -> Expr f x env b -> Expr f x env (TPair a b) + EFst :: x a -> Expr f x env (TPair a b) -> Expr f x env a + ESnd :: x b -> Expr f x env (TPair a b) -> Expr f x env b + ENil :: x TNil -> Expr f x env TNil + EInl :: x (TEither a b) -> STy b -> Expr f x env a -> Expr f x env (TEither a b) + EInr :: x (TEither a b) -> STy a -> Expr f x env b -> Expr f x env (TEither a b) + ECase :: x c -> Expr f x env (TEither a b) -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c + ENothing :: x (TMaybe t) -> STy t -> Expr f x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr f x env t -> Expr f x env (TMaybe t) + EMaybe :: x b -> Expr f x env b -> Expr f x (t : env) b -> Expr f x env (TMaybe t) -> Expr f x env b -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr f x env (TArr n (TScal t)) + EBuild :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x (Tup (Replicate n TIx) : env) t -> Expr f x env (TArr n t) + EMap :: x (TArr n t) -> Expr f x (a : env) t -> Expr f x env (TArr n a) -> Expr f x env (TArr n t) -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr f x (TPair t t : env) t -> Expr f x env t -> Expr f x env (TArr (S n) t) -> Expr f x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EUnit :: x (TArr Z t) -> Expr f x env t -> Expr f x env (TArr Z t) + EReplicate1Inner :: x (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) -> Expr f x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr f x env (TArr (S n) (TScal t)) -> Expr f x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env (TArr m t) -> Expr f x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr f x env (TArr n a) -> Expr f x env (TArr n b) -> Expr f x env (TArr n (TPair a b)) -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: -- an implementation is allowed to parallelise this thing and store the b -- values in some implementation-defined order. -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (TPair t1 t1 : env) (TPair t1 b) - -> Expr x env t1 - -> Expr x env (TArr (S n) t1) - -> Expr x env (TPair (TArr n t1) -- normal primal fold output + -> Expr f x (TPair t1 t1 : env) (TPair t1 b) + -> Expr f x env t1 + -> Expr f x env (TArr (S n) t1) + -> Expr f x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) -- Reverse derivative of EFold1Inner. The contributions to the initial -- element are not yet added together here; we assume a later fusion system -- does that for us. EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative - -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 - -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array + -> Expr f x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr f x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr f x env (TArr n t2) -- incoming cotangent + -> Expr f x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr f x env (TScal t) + EIdx0 :: x t -> Expr f x env (TArr Z t) -> Expr f x env t + EIdx1 :: x (TArr n t) -> Expr f x env (TArr (S n) t) -> Expr f x env TIx -> Expr f x env (TArr n t) + EIdx :: x t -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) -> Expr f x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr f x env (TArr n t) -> Expr f x env (Tup (Replicate n TIx)) + EOp :: x t -> SOp a t -> Expr f x env a -> Expr f x env t -- custom derivatives -- 'b' is the part of the input of the operation that derivatives should @@ -106,43 +107,49 @@ data Expr x env t where -- currently not used very much, so could be relaxed in the future; be sure -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t + -> Expr f x [b, a] t -- ^ regular operation + -> Expr f x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr f x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr f x env a -> Expr f x env b + -> Expr f x env t -- fake halfway checkpointing - ERecompute :: x t -> Expr x env t -> Expr x env t + ERecompute :: x t -> Expr f x env t -> Expr f x env t -- accumulation effect on monoids -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not -- need to create any zeros. - EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EWith :: x (TPair a t) -> SMTy t -> Expr f x env t -> Expr f x (TAccum t : env) a -> Expr f x env (TPair a t) -- The 'Sparse' here is eliminated to dense by UnMonoid. - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxD p t) -> Sparse a b -> Expr f x env b -> Expr f x env (TAccum t) -> Expr f x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t - EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t - EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + EZero :: x t -> SMTy t -> Expr f x env (ZeroInfo t) -> Expr f x env t + EDeepZero :: x t -> SMTy t -> Expr f x env (DeepZeroInfo t) -> Expr f x env t + EPlus :: x t -> SMTy t -> Expr f x env t -> Expr f x env t -> Expr f x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr f x env (AcIdxS p t) -> Expr f x env a -> Expr f x env t -- interface of abstract monoidal types - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr f x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr f x env a -> Expr f x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr f x env b -> Expr f x env (TLEither a b) + ELCase :: x c -> Expr f x env (TLEither a b) -> Expr f x env c -> Expr f x (a : env) c -> Expr f x (b : env) c -> Expr f x env c -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + EError :: x a -> STy a -> String -> Expr f x env a + + -- extension point + EExt :: x a -> STy a -> !(f x env a) -> Expr f x env a +deriving instance (forall ty. Show (x ty), forall env' ty. Show (f x env' ty)) => Show (Expr f x env t) + +data NoExt x env a + deriving (Show) -- | A (well-typed, well-scoped) expression using De Bruijn indices. The full -- 'Expr' type is parametrised on an indexed type of "additional info" (@x@); -- 'Ex' sets this to nothing. -type Ex = Expr (Const ()) +type Ex = Expr NoExt (Const ()) ext :: Const () a ext = Const () @@ -211,7 +218,7 @@ opt2 = \case OIDiv t -> STScal t OMod t -> STScal t -typeOf :: Expr x env t -> STy t +typeOf :: Expr f x env t -> STy t typeOf = \case EVar _ t _ -> t ELet _ _ e -> typeOf e @@ -266,7 +273,9 @@ typeOf = \case EError _ t _ -> t -extOf :: Expr x env t -> x t + EExt _ t _ -> t + +extOf :: Expr f x env t -> x t extOf = \case EVar x _ _ -> x ELet x _ _ -> x @@ -312,12 +321,13 @@ extOf = \case EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x + EExt x _ _ -> x -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt :: TravExtEExt f => (forall a. x a -> x' a) -> Expr f x env t -> Expr f x' env t mapExt f = runIdentity . travExt (Identity . f) -{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} -travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt :: (Applicative f, TravExtEExt fe) + => (forall a. x a -> f (x' a)) -> Expr fe x env t -> f (Expr fe x' env t) travExt f = \case EVar x t i -> EVar <$> f x <*> pure t <*> pure i ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body @@ -363,8 +373,15 @@ travExt f = \case EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s + EExt x t v -> EExt <$> f x <*> pure t <*> travExtEExt f v + +class TravExtEExt fe where + travExtEExt :: Applicative f => (forall a. x a -> f (x' a)) -> fe x env t -> f (fe x' env t) + +instance TravExtEExt NoExt where + travExtEExt _ v = case v of {} -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline :: Expr NoExt x env a -> Expr NoExt x (a : env) t -> Expr NoExt x env t substInline repl = subst $ \x t -> \case IZ -> repl IS i -> EVar x t i @@ -374,14 +391,14 @@ subst0 repl = subst $ \_ t -> \case IZ -> repl IS i -> EVar ext t (IS i) -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t +subst :: (forall a. x a -> STy a -> Idx env a -> Expr NoExt x env' a) + -> Expr NoExt x env t -> Expr NoExt x env' t subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr NoExt x env2 a) -> env' :> envOut - -> Expr x env t - -> Expr x envOut t + -> Expr NoExt x env t + -> Expr NoExt x envOut t subst' f w = \case EVar x t i -> f x t w i ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) @@ -428,13 +445,13 @@ subst' f w = \case EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr f x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr f x env2 t sinkF f' x' t w' = \case IZ -> EVar x' t (w' @> IZ) IS i -> f' x' t (WPop w') i -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr :: env :> env' -> Expr NoExt x env t -> Expr NoExt x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) class KnownScalTy t where knownScalTy :: SScalTy t @@ -495,7 +512,7 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict -cheapExpr :: Expr x env t -> Bool +cheapExpr :: Expr f x env t -> Bool cheapExpr = \case EVar{} -> True ENil{} -> True diff --git a/src/CHAD/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs index c1a1e77..3ecda3e 100644 --- a/src/CHAD/AST/Bindings.hs +++ b/src/CHAD/AST/Bindings.hs @@ -28,7 +28,7 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` -bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush :: Bindings (Expr NoExt x) env binds -> Expr NoExt x (Append binds env) t -> Bindings (Expr NoExt x) env (t : binds) bpush b e = b `BPush` (typeOf e, e) infixl `bpush` @@ -47,8 +47,8 @@ weakenBindings wf w (BPush b (t, x)) = in (BPush b' (t, wf w' x), WCopy w') weakenBindingsE :: env1 :> env2 - -> Bindings (Expr x) env1 binds - -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) + -> Bindings (Expr NoExt x) env1 binds + -> (Bindings (Expr NoExt x) env2 binds, Append binds env1 :> Append binds env2) weakenBindingsE = weakenBindings weakenExpr weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs index 46173d2..1dad758 100644 --- a/src/CHAD/AST/Count.hs +++ b/src/CHAD/AST/Count.hs @@ -338,15 +338,15 @@ envMaskPrj (EMRest b) _ = b envMaskPrj (_ `EMPush` b) IZ = b envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i -occCount :: Idx env a -> Expr x env t -> Occ +occCount :: Idx env a -> Expr NoExt x env t -> Occ occCount idx ex | Some env <- occCountAll ex = fst (occEnvPrj env idx) -occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll :: Expr NoExt x env t -> Some (OccEnv Occ env) occCountAll ex = occCountX SsFull ex $ \env _ -> Some env -pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr :: SList f env -> Expr NoExt x env t -> Ex env t pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) where fullOccEnv :: SList f env -> OccEnv () env env @@ -365,7 +365,7 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) -- occurrence counts. The callback reconstructs a new expression in an -- updated "response" environment. The response must be at least as large as -- the computed usages. -occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t +occCountX :: forall env t t' x r. Substruc t t' -> Expr NoExt x env t -> (forall env'. OccEnv Occ env env' -- response OccEnv must be at least as large as the OccEnv returned above -> (forall env''. OccEnv () env env'' -> Ex env'' t') @@ -885,7 +885,7 @@ occCountX initialS topexpr k = case topexpr of handleReduction :: t ~ TArr n (TScal t2) => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) - -> Expr x env (TArr (S n) (TScal t2)) + -> Expr NoExt x env (TArr (S n) (TScal t2)) -> r handleReduction reduce e | STArr (SS n) _ <- typeOf e = @@ -914,7 +914,7 @@ deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = case count of Zero -> k (SENo sub) _ -> k (SEYesR sub) -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv :: Subenv env env' -> Expr NoExt x env t -> Expr NoExt x env' t unsafeWeakenWithSubenv = \sub -> subst (\x t i -> case sinkViaSubenv i sub of Just i' -> EVar x t i' diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index 9ddcb35..b763efe 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -63,20 +63,20 @@ nameBaseForType _ = "x" genName' :: String -> M String genName' prefix = (prefix ++) . show <$> genId -genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr NoExt x env t -> M String genNameIfUsedIn' prefix ty idx ex | occCount idx ex == mempty = case ty of STNil -> return "()" _ -> return "_" | otherwise = genName' prefix -- TODO: let this return a type-tagged thing so that name environments are more typed than Const -genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn :: STy a -> Idx env a -> Expr NoExt x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t -pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () +pprintExpr :: (KnownEnv env, PrettyX x) => Expr NoExt x env t -> IO () pprintExpr = putStrLn . ppExpr knownEnv -ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String +ppExpr :: PrettyX x => SList STy env -> Expr NoExt x env t -> String ppExpr senv e = render $ fst . flip runM 1 $ do val <- mkVal senv e' <- ppExpr' 0 val e @@ -94,7 +94,7 @@ ppExpr senv e = render $ fst . flip runM 1 $ do name <- genName' "arg" return (Const name `SCons` val) -ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExpr' :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc ppExpr' d val expr = case expr of EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr @@ -374,9 +374,9 @@ ppExpr' d val expr = case expr of EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) -ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExprLet :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc ppExprLet d val etop = do - let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) + let collect :: PrettyX x => SVal env -> Expr NoExt x env t -> M ([(String, Occ, ADoc)], ADoc) collect val' (ELet _ rhs body) = do let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body @@ -426,7 +426,7 @@ ppCommut :: Commutative -> String ppCommut Commut = "(C)" ppCommut Noncommut = "" -ppX :: PrettyX x => Expr x env t -> ADoc +ppX :: PrettyX x => Expr NoExt x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) data Fixity = Prefix | Infix diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index 212cc7d..b637f88 100644 --- a/src/CHAD/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -63,15 +63,15 @@ validSplitEither (VIEither (Right v)) = (Nothing, Just v) validSplitEither (VIEither' v1 v2) = (Just v1, Just v2) -- | Symbolic partial evaluation. -identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t +identityAnalysis :: SList STy env -> Expr NoExt x env t -> Expr NoExt ValId env t identityAnalysis env term = runIdGen 0 $ do env' <- slistMapA genIds env snd <$> idana env' term -identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t +identityAnalysis' :: SList ValId env -> Expr NoExt x env t -> Expr NoExt ValId env t identityAnalysis' env term = snd (runIdGen 0 (idana env term)) -idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) +idana :: SList ValId env -> Expr NoExt x env t -> IdGen (ValId t, Expr NoExt ValId env t) idana env expr = case expr of EVar _ t i -> do let v = slistIdx env i diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs index bfa964b..eba3719 100644 --- a/src/CHAD/Drev.hs +++ b/src/CHAD/Drev.hs @@ -726,7 +726,7 @@ drev :: forall env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> Sparse (D2 t) sd - -> Expr ValId env t -> Ret env sto sd t + -> Expr NoExt ValId env t -> Ret env sto sd t drev des _ sd | isAbsent sd = \e -> Ret BTop @@ -774,7 +774,7 @@ drev des accumMap sd = \case (subenvNone (d2e (select SMerge des))) (ENil ext) - ELet _ (rhs :: Expr _ _ a) body + ELet _ (rhs :: Expr _ _ _ a) body | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs @@ -872,7 +872,7 @@ drev des accumMap sd = \case (EError ext (contribTupTy des sub') "inr<-dinl") (inj1 $ weakenExpr (WCopy WSink) e2)) - ECase _ e (a :: Expr _ _ t) b + ECase _ e (a :: Expr _ _ _ t) b | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge @@ -1041,7 +1041,7 @@ drev des accumMap sd = \case (subenvNone (d2e (select SMerge des))) (ENil ext) - EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ _ eltty) | SpArr @_ @sdElt sdElt <- sd , let eltty = typeOf ef , shty :: STy shty <- tTup (sreplicate ndim tIx) @@ -1081,7 +1081,7 @@ drev des accumMap sd = \case (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) e2) - EMap _ ef (earr :: Expr _ _ (TArr n a)) + EMap _ ef (earr :: Expr _ _ _ (TArr n a)) | SpArr sdElt <- sd , let STArr ndim t1 = typeOf earr t2 = typeOf ef -> @@ -1391,7 +1391,7 @@ deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> Sparse (D2s t) sd - -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) + -> Expr NoExt ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) deriv_extremum extremum des accumMap sd e | at@(STArr (SS n) t@(STScal st)) <- typeOf e , let at' = STArr n t @@ -1437,7 +1437,7 @@ drevScoped :: forall a s env sto sd t. => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> STy a -> Storage s -> Maybe (ValId a) -> Sparse (D2 t) sd - -> Expr ValId (a : env) t + -> Expr NoExt ValId (a : env) t -> RetScoped env sto a s sd t drevScoped des accumMap argty argsto argids sd expr = case argsto of SMerge @@ -1496,7 +1496,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) -> VarMap Int (D2AcE (Select env sto "accum")) -> (STy a, Storage s) -> Sparse (D2 t) dt - -> Expr ValId (a : env) t + -> Expr NoExt ValId (a : env) t -> (forall provars shbinds tape d2a'. SList STy provars -> Subenv (D2E (Select env sto "merge")) (D2E provars) @@ -1574,7 +1574,7 @@ drevLambda des accumMap (argty, argsto) sd origef k = prf1 _ _ SDiscr = Refl -- TODO: proper primal-only transform that doesn't depend on D1 = Id -drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal :: Descr env sto -> Expr NoExt x env t -> Ex (D1E env) (D1 t) drevPrimal des e | Refl <- d1Identity (typeOf e) , Refl <- d1eIdentity (descrList des) diff --git a/src/CHAD/Example/GMM.hs b/src/CHAD/Example/GMM.hs index 18641e8..2b2ac2b 100644 --- a/src/CHAD/Example/GMM.hs +++ b/src/CHAD/Example/GMM.hs @@ -112,7 +112,7 @@ gmmObjective wrong = fromNamed $ qmat q l = inline qmat' (SNil .$ q .$ l) in let_ #k2arr (unit #k2) $ - #k1 - + idx0 (sum1i (build1 #N $ #i :-> + + idx0 (sum1i (build1 #N $ #i :-> recompute $ logsumexp (build1 #K $ #k :-> #alpha ! pair nil #k + idx0 (sum1i (#Q .! #k)) diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs new file mode 100644 index 0000000..757667f --- /dev/null +++ b/src/CHAD/Fusion.hs @@ -0,0 +1,115 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Fusion where + +import Data.Dependent.Map (DMap) +-- import Data.Dependent.Map qualified as DMap +import Data.Functor.Const +import Data.Kind (Type) +import Numeric.Natural + +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Data + + +-- TODO: +-- A bunch of data types are defined here that should be able to express a +-- graph of loop nests. A graph is a straight-line program whose statements +-- are, in this case, loop nests. A loop nest corresponds to what fusion +-- normally calls a "cluster", but is here represented as, well, a loop nest. +-- +-- No unzipping is done here, as I don't think it is necessary: I haven't been +-- able to think of programs that get more fusion opportunities when unzipped +-- than when zipped. If any such programs exist, I in any case conjecture that +-- with a pre-pass that splits array operations that can be unzipped already at +-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1), +-- build n (\i -> E2))), all such fusion opportunities can be recovered. If +-- this conjecture is false, some reorganisation may be required. +-- +-- Next steps, perhaps: +-- 1. Express a build operation as a LoopNest, not from the EBuild constructor +-- specifically but its fields. It will have a single output, and its args +-- will be its list of free variables. +-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out. +-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of +-- a large expression that contain only "simple" AST constructors, and +-- replaces those subterms with an EExt containing that graph. In this +-- construction process, EBuild and ESum1Inner should be replaced with +-- FLoop. +-- 4. Implement fusion somehow on graphs! +-- 5. Add an AST constructor for a loop nest (which most of the modules throw +-- an error on, except Count, Simplify and Compile -- good luck with Count), +-- and compile that to an actual C loop nest. +-- 6. Extend to other cool operations like EFold1InnerD1 + + +type FEx = Expr FGraph (Const ()) + +type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type +data FGraph x env t where + FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t + +data Node env t where + FFreeVar :: STy t -> Idx env t -> Node env t + FLoop :: SList NodeId args + -> SList STy outs + -> LoopNest args outs + -> Tuple (Idx outs) t + -> Node env t + +data NodeId t = NodeId Natural (STy t) + deriving (Show) + +data Tuple f t where + TupNil :: Tuple f TNil + TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b) + TupSingle :: f t -> Tuple f t +deriving instance (forall a. Show (f a)) => Show (Tuple f t) + +data LoopNest args outs where + Inner :: Bindings Ex args bs + -> SList (Idx (Append bs args)) outs + -> LoopNest args outs + -- this should be able to express a simple nesting of builds and sums. + Layer :: Bindings Ex args bs1 + -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) + -> LoopNest (TIx : Append bs1 args) loopouts + -> Partition BuildUp RedSum loopouts mapouts sumouts + -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 + -> SList (Idx (Append bs2 args)) outs + -> LoopNest args (Append outs mapouts) + +type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type +data Partition f1 f2 ts ts1 ts2 where + PNil :: Partition f1 f2 '[] '[] '[] + Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2 + Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2) + +data BuildUp t t' where + BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t) + +data RedSum t t' where + RedSum :: SMTy t -> RedSum t t + +-- type family Unzip t where +-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b) +-- Unzip (TArr n t) = UnzipA n t + +-- type family UnzipA n t where +-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b) +-- UnzipA n t = TArr n t + +-- data Zipping ut t where +-- ZId :: Zipping t t +-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b) +-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b)) +-- deriving instance Show (Zipping ut t) + + diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs index ea253d6..a09effc 100644 --- a/src/CHAD/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -364,7 +364,7 @@ simplify'Rec = \case -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. -hasAdds :: Expr x env t -> Bool +hasAdds :: Expr NoExt x env t -> Bool hasAdds = \case EVar _ _ _ -> False ELet _ rhs body -> hasAdds rhs || hasAdds body |
