diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/AST.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 705 |
1 files changed, 0 insertions, 705 deletions
diff --git a/src/AST.hs b/src/AST.hs deleted file mode 100644 index ca6cdd1..0000000 --- a/src/AST.hs +++ /dev/null @@ -1,705 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where - -import Data.Functor.Const -import Data.Functor.Identity -import Data.Int (Int64) -import Data.Kind (Type) - -import Array -import AST.Accum -import AST.Sparse.Types -import AST.Types -import AST.Weaken -import CHAD.Types -import Data - - --- General assumption: head of the list (whatever way it is associated) is the --- inner variable / inner array dimension. In pretty printing, the inner --- variable / inner dimension is printed on the _right_. --- --- All the monoid operations are unsupposed as the input to CHAD, and are --- intended to be eliminated after simplification, so that the input program as --- well as the output program do not contain these constructors. --- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where - -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t - - -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - - -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) - -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) - - -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: - -- an implementation is allowed to parallelise this thing and store the b - -- values in some implementation-defined order. - -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. - EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (TPair t1 t1 : env) (TPair t1 b) - -> Expr x env t1 - -> Expr x env (TArr (S n) t1) - -> Expr x env (TPair (TArr n t1) -- normal primal fold output - (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) - -- Reverse derivative of EFold1Inner. The contributions to the initial - -- element are not yet added together here; we assume a later fusion system - -- does that for us. - EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative - -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 - -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array - - -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t - - -- custom derivatives - -- 'b' is the part of the input of the operation that derivatives should - -- be backpropagated to; 'a' is the inactive part. The dual field of - -- ECustom does not allow a derivative to be generated for 'a', and hence - -- none is propagated. - -- No accumulators are allowed inside a, b and tape. This restriction is - -- currently not used very much, so could be relaxed in the future; be sure - -- to check this requirement whenever it is necessary for soundness! - ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t - - -- fake halfway checkpointing - ERecompute :: x t -> Expr x env t -> Expr x env t - - -- accumulation effect on monoids - -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it - -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not - -- need to create any zeros. - EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - -- The 'Sparse' here is eliminated to dense by UnMonoid. - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil - - -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t - EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t - EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t - - -- interface of abstract monoidal types - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - - -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) - -type Ex = Expr (Const ()) - -ext :: Const () a -ext = Const () - -data Commutative = Commut | Noncommut - deriving (Show) - -type SOp :: Ty -> Ty -> Type -data SOp a t where - OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - ONot :: SOp (TScal TBool) (TScal TBool) - OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right - ORound64 :: SOp (TScal TF64) (TScal TI64) - OToFl64 :: SOp (TScal TI64) (TScal TF64) - ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) -deriving instance Show (SOp a t) - -opt1 :: SOp a t -> STy a -opt1 = \case - OAdd t -> STPair (STScal t) (STScal t) - OMul t -> STPair (STScal t) (STScal t) - ONeg t -> STScal t - OLt t -> STPair (STScal t) (STScal t) - OLe t -> STPair (STScal t) (STScal t) - OEq t -> STPair (STScal t) (STScal t) - ONot -> STScal STBool - OAnd -> STPair (STScal STBool) (STScal STBool) - OOr -> STPair (STScal STBool) (STScal STBool) - OIf -> STScal STBool - ORound64 -> STScal STF64 - OToFl64 -> STScal STI64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STPair (STScal t) (STScal t) - OMod t -> STPair (STScal t) (STScal t) - -opt2 :: SOp a t -> STy t -opt2 = \case - OAdd t -> STScal t - OMul t -> STScal t - ONeg t -> STScal t - OLt _ -> STScal STBool - OLe _ -> STScal STBool - OEq _ -> STScal STBool - ONot -> STScal STBool - OAnd -> STScal STBool - OOr -> STScal STBool - OIf -> STEither STNil STNil - ORound64 -> STScal STI64 - OToFl64 -> STScal STF64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STScal t - OMod t -> STScal t - -typeOf :: Expr x env t -> STy t -typeOf = \case - EVar _ t _ -> t - ELet _ _ e -> typeOf e - - EPair _ a b -> STPair (typeOf a) (typeOf b) - EFst _ e | STPair t _ <- typeOf e -> t - ESnd _ e | STPair _ t <- typeOf e -> t - ENil _ -> STNil - EInl _ t2 e -> STEither (typeOf e) t2 - EInr _ t1 e -> STEither t1 (typeOf e) - ECase _ _ a _ -> typeOf a - ENothing _ t -> STMaybe t - EJust _ e -> STMaybe (typeOf e) - EMaybe _ e _ _ -> typeOf e - ELNil _ t1 t2 -> STLEither t1 t2 - ELInl _ t2 e -> STLEither (typeOf e) t2 - ELInr _ t1 e -> STLEither t1 (typeOf e) - ELCase _ _ a _ _ -> typeOf a - - EConstArr _ n t _ -> STArr n (STScal t) - EBuild _ n _ e -> STArr n (typeOf e) - EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) - EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t - ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EUnit _ e -> STArr SZ (typeOf e) - EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t - EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t - EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) - - EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) - EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) - - EConst _ t _ -> STScal t - EIdx0 _ e | STArr _ t <- typeOf e -> t - EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) - EOp _ op _ -> opt2 op - - ECustom _ _ _ _ e _ _ _ _ -> typeOf e - ERecompute _ e -> typeOf e - - EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ _ -> STNil - - EZero _ t _ -> fromSMTy t - EDeepZero _ t _ -> fromSMTy t - EPlus _ t _ _ -> fromSMTy t - EOneHot _ t _ _ _ -> fromSMTy t - - EError _ t _ -> t - -extOf :: Expr x env t -> x t -extOf = \case - EVar x _ _ -> x - ELet x _ _ -> x - EPair x _ _ -> x - EFst x _ -> x - ESnd x _ -> x - ENil x -> x - EInl x _ _ -> x - EInr x _ _ -> x - ECase x _ _ _ -> x - ENothing x _ -> x - EJust x _ -> x - EMaybe x _ _ _ -> x - ELNil x _ _ -> x - ELInl x _ _ -> x - ELInr x _ _ -> x - ELCase x _ _ _ _ -> x - EConstArr x _ _ _ -> x - EBuild x _ _ _ -> x - EMap x _ _ -> x - EFold1Inner x _ _ _ _ -> x - ESum1Inner x _ -> x - EUnit x _ -> x - EReplicate1Inner x _ _ -> x - EMaximum1Inner x _ -> x - EMinimum1Inner x _ -> x - EReshape x _ _ _ -> x - EZip x _ _ -> x - EFold1InnerD1 x _ _ _ _ -> x - EFold1InnerD2 x _ _ _ _ -> x - EConst x _ _ -> x - EIdx0 x _ -> x - EIdx1 x _ _ -> x - EIdx x _ _ -> x - EShape x _ -> x - EOp x _ _ -> x - ECustom x _ _ _ _ _ _ _ _ -> x - ERecompute x _ -> x - EWith x _ _ _ -> x - EAccum x _ _ _ _ _ _ -> x - EZero x _ _ -> x - EDeepZero x _ _ -> x - EPlus x _ _ _ -> x - EOneHot x _ _ _ _ -> x - EError x _ _ -> x - -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = runIdentity . travExt (Identity . f) - -{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} -travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) -travExt f = \case - EVar x t i -> EVar <$> f x <*> pure t <*> pure i - ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body - EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b - EFst x e -> EFst <$> f x <*> travExt f e - ESnd x e -> ESnd <$> f x <*> travExt f e - ENil x -> ENil <$> f x - EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e - EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e - ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b - ENothing x t -> ENothing <$> f x <*> pure t - EJust x e -> EJust <$> f x <*> travExt f e - EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e - ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 - ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e - ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e - ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c - EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a - EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b - EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b - EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e - EUnit x e -> EUnit <$> f x <*> travExt f e - EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b - EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e - EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e - EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b - EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b - EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EConst x t v -> EConst <$> f x <*> pure t <*> pure v - EIdx0 x e -> EIdx0 <$> f x <*> travExt f e - EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b - EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es - EShape x e -> EShape <$> f x <*> travExt f e - EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e - ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 - ERecompute x e -> ERecompute <$> f x <*> travExt f e - EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 - EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e - EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e - EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b - EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b - EError x t s -> EError <$> f x <*> pure t <*> pure s - -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t -substInline repl = - subst $ \x t -> \case IZ -> repl - IS i -> EVar x t i - -subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t -subst0 repl = - subst $ \_ t -> \case IZ -> repl - IS i -> EVar ext t (IS i) - -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t -subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId - -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) - -> env' :> envOut - -> Expr x env t - -> Expr x envOut t -subst' f w = \case - EVar x t i -> f x t w i - ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) - EPair x a b -> EPair x (subst' f w a) (subst' f w b) - EFst x e -> EFst x (subst' f w e) - ESnd x e -> ESnd x (subst' f w e) - ENil x -> ENil x - EInl x t e -> EInl x t (subst' f w e) - EInr x t e -> EInr x t (subst' f w e) - ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) - ENothing x t -> ENothing x t - EJust x e -> EJust x (subst' f w e) - EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) - ELNil x t1 t2 -> ELNil x t1 t2 - ELInl x t e -> ELInl x t (subst' f w e) - ELInr x t e -> ELInr x t (subst' f w e) - ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) - ESum1Inner x e -> ESum1Inner x (subst' f w e) - EUnit x e -> EUnit x (subst' f w e) - EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) - EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) - EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) - EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) - EZip x a b -> EZip x (subst' f w a) (subst' f w b) - EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) - EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (subst' f w e) - EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) - EShape x e -> EShape x (subst' f w e) - EOp x op e -> EOp x op (subst' f w e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - ERecompute x e -> ERecompute x (subst' f w e) - EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) - EZero x t e -> EZero x t (subst' f w e) - EDeepZero x t e -> EDeepZero x t (subst' f w e) - EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) - EError x t s -> EError x t s - where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t - sinkF f' x' t w' = \case - IZ -> EVar x' t (w' @> IZ) - IS i -> f' x' t (WPop w') i - -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t -weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) - -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy - -class KnownMTy t where knownMTy :: SMTy t -instance KnownMTy TNil where knownMTy = SMTNil -instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy -instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy -instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy -instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy -instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- smtyKnown t = Dict - -smtyKnown :: SMTy t -> Dict (KnownMTy t) -smtyKnown SMTNil = Dict -smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict -smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict -smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - EUnit _ e -> cheapExpr e - _ -> False - -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - -ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = - EBuild ext (SS n) (EPair ext sh size) $ - let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ - in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) - (EFst ext arg) - -eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eidxEq SZ _ _ = EConst ext STBool True -eidxEq (SS SZ) a b = - EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) -eidxEq (SS n) a b - | let ty = tTup (sreplicate (SS n) tIx) - = ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EOp ext OAnd $ EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) - (ESnd ext (EVar ext ty IZ)))) - (eidxEq n (EFst ext (EVar ext ty (IS IZ))) - (EFst ext (EVar ext ty IZ))) - -emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr - | STArr _ t <- typeOf arr - , Dict <- styKnown t - = EMap ext f arr - -ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 - | STArr _ t1 <- typeOf arr1 - , STArr _ t2 <- typeOf arr2 - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) - IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) - IS (IS i) -> EVar ext t (IS i)) - f) - (EZip ext arr1 arr2) - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip = EZip ext - -eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a -eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) - --- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). -eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eshapeEmpty SZ _ = EConst ext STBool False -eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) -eshapeEmpty (SS n) e = - ELet ext e $ - EOp ext OAnd (EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) - (EConst ext STI64 0))) - (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) - -eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) -eshapeConst ShNil = ENil ext -eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) - -eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -eshapeProd SZ _ = EConst ext STI64 1 -eshapeProd (SS SZ) e = ESnd ext e -eshapeProd (SS n) e = - eunPair e $ \_ e1 e2 -> - EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) - -eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) -eflatten e = - let STArr n _ = typeOf e - in elet e $ - EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) - --- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) --- ezeroD2 t ezi = EZero ext (d2M t) ezi - --- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil --- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea - --- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) --- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev - -eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r -eunPair (EPair _ e1 e2) k = k WId e1 e2 -eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) -eunPair e k = - elet e $ - k WSink - (EFst ext (evar IZ)) - (ESnd ext (evar IZ)) - -efst :: Ex env (TPair a b) -> Ex env a -efst (EPair _ e1 _) = e1 -efst e = EFst ext e - -esnd :: Ex env (TPair a b) -> Ex env b -esnd (EPair _ _ e2) = e2 -esnd e = ESnd ext e - -elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b -elet rhs body - | Dict <- styKnown (typeOf rhs) - = if cheapExpr rhs - then substInline rhs body - else ELet ext rhs body - --- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) -use :: Ex env a -> Ex env b -> Ex env b -use a b = elet a $ weakenExpr WSink b - -emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b -emaybe e a b - | STMaybe t <- typeOf e - , Dict <- styKnown t - = EMaybe ext a b e - -ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c -ecase e a b - | STEither t1 t2 <- typeOf e - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = ECase ext e a b - -elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c -elcase e a b c - | STLEither t1 t2 <- typeOf e - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = ELCase ext e a b c - -evar :: KnownTy a => Idx env a -> Ex env a -evar = EVar ext knownTy - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext - -splitSparsePair - :: -- given a sparsity - STy (TPair a b) -> Sparse (TPair a b) t' - -> (forall a' b'. - -- I give you back two sparsities for a and b - Sparse a a' -> Sparse b b' - -- furthermore, I tell you that either your t' is already this (a', b') pair... - -> Either - (t' :~: TPair a' b') - -- or I tell you how to construct a' and b' from t', given an actual t' - (forall r' env. - Idx env t' - -> (forall env'. - (forall c. Ex env' c -> Ex env c) - -> Ex env' a' -> Ex env' b' -> r') - -> r') - -> r) - -> r -splitSparsePair _ SpAbsent k = - k SpAbsent SpAbsent $ Right $ \_ k2 -> - k2 id (ENil ext) (ENil ext) -splitSparsePair _ (SpPair s1 s2) k1 = - k1 s1 s2 $ Left Refl -splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = - let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in - k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> - k2 (elet $ - emaybe (EVar ext (STMaybe (applySparse s t)) i) - (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) - (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) - (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) - -splitSparsePair _ (SpSparse SpAbsent) k = - k SpAbsent SpAbsent $ Right $ \_ k2 -> - k2 id (ENil ext) (ENil ext) --- -- TODO: having to handle sparse-of-sparse at all is ridiculous -splitSparsePair t (SpSparse (SpSparse s)) k = - splitSparsePair t (SpSparse s) $ \s1 s2 eres -> - k s1 s2 $ Right $ \i k2 -> - case eres of - Left refl -> case refl of {} - Right f -> - f IZ $ \wrap e1 e2 -> - k2 (\body -> - elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) - (ENothing ext (applySparse s t)) - (evar IZ)) $ - wrap body) - e1 e2 |
