{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where

import Data.Functor.Const
import Data.Functor.Identity
import Data.Int (Int64)
import Data.Kind (Type)

import Array
import AST.Accum
import AST.Sparse.Types
import AST.Types
import AST.Weaken
import CHAD.Types
import Data


-- General assumption: the head of the list (whichever way it is associated) is
-- the inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
--
-- The monoid operations are not supported in the input to CHAD; they are
-- intended to be eliminated after simplification, so that neither the input
-- program nor the output program contains these constructors.
-- TODO: ensure this by a "stage" type parameter.
-- | Typed expressions with de Bruijn-indexed variables.
--
-- * @x@ is a per-node annotation ("extension"): every constructor stores an
--   @x@ of its own result type as its first field (see 'extOf').
-- * @env@ is the typing environment, a type-level list into which 'Idx'
--   points; binders such as 'ELet' and 'ECase' push the bound variable onto
--   the head of the list.
-- * @t@ is the type of the expression.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  -- The 'STy' field records the type of the alternative that is not present,
  -- which cannot be recovered from the payload.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
  ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
  EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
  -- Fields: Nothing branch, Just branch (binds the payload), scrutinee (last).
  EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b

  -- array operations
  EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
  -- Fields: rank, shape, and the element function (binds the index tuple).
  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
  EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
  -- The combining function receives its two operands as a single
  -- 'TPair t t' variable; the next fields are the initial value and the
  -- array, whose innermost dimension is folded away.
  -- NOTE(review): the comment previously here described the operands as a
  -- "t : t : env" environment ("bottommost t is the rightmost argument"),
  -- which no longer matches the 'TPair t t : env' type below; confirm which
  -- pair component is the left operand.
  EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
  ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
  EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
  EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  -- Fields: new rank, new shape, and the source array (of any rank m).
  EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
  EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))

  -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
  -- an implementation is allowed to parallelise this thing and store the b
  -- values in some implementation-defined order.
  -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
  EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
                -> Expr x (TPair t1 t1 : env) (TPair t1 b)
                -> Expr x env t1
                -> Expr x env (TArr (S n) t1)
                -> Expr x env (TPair (TArr n t1)  -- normal primal fold output
                                     (TArr (S n) b))  -- additional stores; usually: (prescanl, the tape stores)
  -- Reverse derivative of EFold1Inner. The contributions to the initial
  -- element are not yet added together here; we assume a later fusion system
  -- does that for us.
  EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
                -> Expr x (t2 : b : env) (TPair t2 t2)  -- reverse derivative of function (should contribute to free variables via accumulation)
                -> Expr x env (TArr (S n) b)  -- stores from EFold1InnerD1
                -> Expr x env (TArr n t2)  -- incoming cotangent
                -> Expr x env (TPair (TArr n t2) (TArr (S n) t2))  -- outgoing cotangents to x0 (not summed) and input array

  -- expression operations
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
  -- Index with a single 'TIx', removing one dimension.
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- custom derivatives
  -- 'b' is the part of the input of the operation that derivatives should
  -- be backpropagated to; 'a' is the inactive part. The dual field of
  -- ECustom does not allow a derivative to be generated for 'a', and hence
  -- none is propagated.
  -- No accumulators are allowed inside a, b and tape. This restriction is
  -- currently not used very much, so could be relaxed in the future; be sure
  -- to check this requirement whenever it is necessary for soundness!
  -- The three code fields are closed: their environments are fixed lists,
  -- not 'env'.
  ECustom :: x t -> STy a -> STy b -> STy tape
          -> Expr x [b, a] t  -- ^ regular operation
          -> Expr x [D1 b, D1 a] (TPair (D1 t) tape)  -- ^ CHAD forward pass
          -> Expr x [D2 t, tape] (D2 b)  -- ^ CHAD reverse derivative
          -> Expr x env a
          -> Expr x env b
          -> Expr x env t

  -- fake halfway checkpointing
  ERecompute :: x t -> Expr x env t -> Expr x env t

  -- accumulation effect on monoids
  -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
  -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
  -- need to create any zeros.
  EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
  -- The 'Sparse' here is eliminated to dense by UnMonoid.
  EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil

  -- monoidal operations (to be desugared to regular operations after simplification)
  EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
  EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
  EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
  EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t

  -- interface of abstract monoidal types
  ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
  ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
  ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
  -- Fields: scrutinee, nil branch, left branch (binds a), right branch (binds b).
  ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- partiality
  EError :: x a -> STy a -> String -> Expr x env a

-- Showable whenever the annotation is showable at every type.
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions carrying no annotation.
type Ex = Expr (Const ())

-- | The trivial annotation used when building 'Ex' terms.
ext :: Const () a
ext = Const ()

-- | Flag carried by the fold constructors ('EFold1Inner' and its derivative
-- forms); presumably records whether the combining function may be treated as
-- commutative.
data Commutative = Commut | Noncommut
  deriving (Show)

-- | Primitive operations, indexed by argument type @a@ and result type @t@.
-- Binary operations take their two operands as a single pair.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OIf :: SOp (TScal TBool) (TEither TNil TNil)  -- True is Left, False is Right
  ORound64 :: SOp (TScal TF64) (TScal TI64)
  OToFl64 :: SOp (TScal TI64) (TScal TF64)
  ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)

-- | The argument type of a primitive operation.
opt1 :: SOp a t -> STy a
opt1 = \case
  OAdd t -> STPair (STScal t) (STScal t)
  OMul t -> STPair (STScal t) (STScal t)
  ONeg t -> STScal t
  OLt t -> STPair (STScal t) (STScal t)
  OLe t -> STPair (STScal t) (STScal t)
  OEq t -> STPair (STScal t) (STScal t)
  ONot -> STScal STBool
  OAnd -> STPair (STScal STBool) (STScal STBool)
  OOr -> STPair (STScal STBool) (STScal STBool)
  OIf -> STScal STBool
  ORound64 -> STScal STF64
  OToFl64 -> STScal STI64
  ORecip t -> STScal t
  OExp t -> STScal t
  OLog t -> STScal t
  OIDiv t -> STPair (STScal t) (STScal t)
  OMod t -> STPair (STScal t) (STScal t)

-- | The result type of a primitive operation.
opt2 :: SOp a t -> STy t
opt2 = \case
  OAdd t -> STScal t
  OMul t -> STScal t
  ONeg t -> STScal t
  OLt _ -> STScal STBool
  OLe _ -> STScal STBool
  OEq _ -> STScal STBool
  ONot -> STScal STBool
  OAnd -> STScal STBool
  OOr -> STScal STBool
  OIf -> STEither STNil STNil
  ORound64 -> STScal STI64
  OToFl64 -> STScal STF64
  ORecip t -> STScal t
  OExp t -> STScal t
  OLog t -> STScal t
  OIDiv t -> STScal t
  OMod t -> STScal t

-- | Reconstruct the type of an expression from the singletons stored in its
-- nodes. The pattern guards on a subterm's type (e.g. @STArr (SS n) t@) are
-- forced to succeed by the GADT indices of the constructor.
typeOf :: Expr x env t -> STy t
typeOf = \case
  EVar _ t _ -> t
  ELet _ _ e -> typeOf e
  EPair _ a b -> STPair (typeOf a) (typeOf b)
  EFst _ e | STPair t _ <- typeOf e -> t
  ESnd _ e | STPair _ t <- typeOf e -> t
  ENil _ -> STNil
  EInl _ t2 e -> STEither (typeOf e) t2
  EInr _ t1 e -> STEither t1 (typeOf e)
  ECase _ _ a _ -> typeOf a
  ENothing _ t -> STMaybe t
  EJust _ e -> STMaybe (typeOf e)
  EMaybe _ e _ _ -> typeOf e
  ELNil _ t1 t2 -> STLEither t1 t2
  ELInl _ t2 e -> STLEither (typeOf e) t2
  ELInr _ t1 e -> STLEither t1 (typeOf e)
  ELCase _ _ a _ _ -> typeOf a

  EConstArr _ n t _ -> STArr n (STScal t)
  EBuild _ n _ e -> STArr n (typeOf e)
  EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
  EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
  ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EUnit _ e -> STArr SZ (typeOf e)
  EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
  EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
  EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
  EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
  EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)

  EConst _ t _ -> STScal t
  EIdx0 _ e | STArr _ t <- typeOf e -> t
  EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
  EIdx _ e _ | STArr _ t <- typeOf e -> t
  EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
  EOp _ op _ -> opt2 op

  -- The type is taken from the regular-operation field.
  ECustom _ _ _ _ e _ _ _ _ -> typeOf e
  ERecompute _ e -> typeOf e

  EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
  EAccum _ _ _ _ _ _ _ -> STNil

  EZero _ t _ -> fromSMTy t
  EDeepZero _ t _ -> fromSMTy t
  EPlus _ t _ _ -> fromSMTy t
  EOneHot _ t _ _ _ -> fromSMTy t

  EError _ t _ -> t

-- | The annotation stored at the root of an expression (always the first
-- constructor field).
extOf :: Expr x env t -> x t
extOf = \case
  EVar x _ _ -> x
  ELet x _ _ -> x
  EPair x _ _ -> x
  EFst x _ -> x
  ESnd x _ -> x
  ENil x -> x
  EInl x _ _ -> x
  EInr x _ _ -> x
  ECase x _ _ _ -> x
  ENothing x _ -> x
  EJust x _ -> x
  EMaybe x _ _ _ -> x
  ELNil x _ _ -> x
  ELInl x _ _ -> x
  ELInr x _ _ -> x
  ELCase x _ _ _ _ -> x
  EConstArr x _ _ _ -> x
  EBuild x _ _ _ -> x
  EMap x _ _ -> x
  EFold1Inner x _ _ _ _ -> x
  ESum1Inner x _ -> x
  EUnit x _ -> x
  EReplicate1Inner x _ _ -> x
  EMaximum1Inner x _ -> x
  EMinimum1Inner x _ -> x
  EReshape x _ _ _ -> x
  EZip x _ _ -> x
  EFold1InnerD1 x _ _ _ _ -> x
  EFold1InnerD2 x _ _ _ _ -> x
  EConst x _ _ -> x
  EIdx0 x _ -> x
  EIdx1 x _ _ -> x
  EIdx x _ _ -> x
  EShape x _ -> x
  EOp x _ _ -> x
  ECustom x _ _ _ _ _ _ _ _ -> x
  ERecompute x _ -> x
  EWith x _ _ _ -> x
  EAccum x _ _ _ _ _ _ -> x
  EZero x _ _ -> x
  EDeepZero x _ _ -> x
  EPlus x _ _ _ -> x
  EOneHot x _ _ _ _ -> x
  EError x _ _ -> x

-- | Replace every annotation in an expression, node by node.
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
mapExt f = runIdentity . travExt (Identity . f)

{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
-- | Effectful traversal over all annotations. At each node the node's own
-- annotation is visited first, then its sub-expressions in constructor field
-- order; non-expression fields are passed through with 'pure'. The closed
-- code fields of 'ECustom' are traversed as well.
travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
travExt f = \case
  EVar x t i -> EVar <$> f x <*> pure t <*> pure i
  ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
  EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
  EFst x e -> EFst <$> f x <*> travExt f e
  ESnd x e -> ESnd <$> f x <*> travExt f e
  ENil x -> ENil <$> f x
  EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
  EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
  ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
  ENothing x t -> ENothing <$> f x <*> pure t
  EJust x e -> EJust <$> f x <*> travExt f e
  EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
  ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
  ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
  ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
  ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
  EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
  EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
  EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
  EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
  ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
  EUnit x e -> EUnit <$> f x <*> travExt f e
  EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
  EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
  EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
  EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
  EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
  EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
  EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
  EConst x t v -> EConst <$> f x <*> pure t <*> pure v
  EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
  EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
  EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
  EShape x e -> EShape <$> f x <*> travExt f e
  EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
  ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
  ERecompute x e -> ERecompute <$> f x <*> travExt f e
  EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
  EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
  EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
  EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
  EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
  EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
  EError x t s -> EError <$> f x <*> pure t <*> pure s

-- | Substitute @repl@ for the innermost variable ('IZ'); every other variable
-- moves down by one index.
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl = subst $ \x t -> \case
  IZ -> repl
  IS i -> EVar x t i

-- | Replace the innermost variable by @repl@. Unlike 'substInline', the
-- replacement lives in an environment that still has a (different) head
-- variable, so all other variables keep their indices.
subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
subst0 repl = subst $ \_ t -> \case
  IZ -> repl
  IS i -> EVar ext t (IS i)

-- | Simultaneous substitution: every free variable is replaced by the
-- expression @f@ returns for it. Under binders, those results are weakened
-- past the new variables.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId

-- | Worker for 'subst' and 'weakenExpr'. The callback also receives the
-- weakening from @env'@ to the environment at the use site, so it can build
-- an expression that is valid underneath any binders traversed so far.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' f w = \case
  EVar x t i -> f x t w i
  ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
  EPair x a b -> EPair x (subst' f w a) (subst' f w b)
  EFst x e -> EFst x (subst' f w e)
  ESnd x e -> ESnd x (subst' f w e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (subst' f w e)
  EInr x t e -> EInr x t (subst' f w e)
  ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
  ENothing x t -> ENothing x t
  EJust x e -> EJust x (subst' f w e)
  EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
  ELNil x t1 t2 -> ELNil x t1 t2
  ELInl x t e -> ELInl x t (subst' f w e)
  ELInr x t e -> ELInr x t (subst' f w e)
  ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
  EConstArr x n t a -> EConstArr x n t a
  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
  EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
  ESum1Inner x e -> ESum1Inner x (subst' f w e)
  EUnit x e -> EUnit x (subst' f w e)
  EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
  EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
  EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
  EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
  EZip x a b -> EZip x (subst' f w a) (subst' f w b)
  EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
  -- The reverse function binds two variables (t2 : b : env), hence two lifts.
  EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
  EConst x t v -> EConst x t v
  EIdx0 x e -> EIdx0 x (subst' f w e)
  EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
  EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
  EShape x e -> EShape x (subst' f w e)
  EOp x op e -> EOp x op (subst' f w e)
  -- The code fields of ECustom are closed, so only the arguments are rewritten.
  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
  ERecompute x e -> ERecompute x (subst' f w e)
  EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
  EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
  EZero x t e -> EZero x t (subst' f w e)
  EDeepZero x t e -> EDeepZero x t (subst' f w e)
  EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
  EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
  EError x t s -> EError x t s
  where
    -- Lift the callback under one binder: the newly bound variable maps to
    -- itself (through the weakening), all other variables are delegated to
    -- the original callback with the binder popped off the weakening.
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

-- | Move an expression into a larger environment, renaming every variable
-- along the given weakening.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))

-- | Implicitly available singleton for a scalar type.
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
instance KnownScalTy TF32 where knownScalTy = STF32
instance KnownScalTy TF64 where knownScalTy = STF64
instance KnownScalTy TBool where knownScalTy = STBool

-- | Implicitly available singleton for a type.
class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy

-- | Implicitly available singleton for a monoid type. Only numeric scalars
-- have an instance.
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy

-- | Implicitly available list of singletons for an environment.
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv

-- | Turn an explicit type singleton back into a 'KnownTy' dictionary.
styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict

-- | Turn an explicit monoid-type singleton back into a 'KnownMTy' dictionary.
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict

-- | Turn an explicit scalar-type singleton back into a 'KnownScalTy' dictionary.
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
sscaltyKnown STI64 = Dict
sscaltyKnown STF32 = Dict
sscaltyKnown STF64 = Dict
sscaltyKnown STBool = Dict

-- | Turn an explicit environment singleton back into a 'KnownEnv' dictionary.
envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict

-- | Conservative syntactic test for expressions that are cheap enough to
-- duplicate: variables, nil, constants, and 'EFst' / 'ESnd' / 'EUnit' of such
-- expressions. Everything else is reported as not cheap.
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
  EVar{} -> True
  ENil{} -> True
  EConst{} -> True
  EFst _ e -> cheapExpr e
  ESnd _ e -> cheapExpr e
  EUnit _ e -> cheapExpr e
  _ -> False

-- | Build a tuple from a list of component expressions with 'mkTup', using
-- 'ENil' for the empty tuple and 'EPair' to extend it.
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)

-- | Build an array one rank higher than the arrays produced by @f@.
-- The result has shape @(sh, size)@. The element at index @(idx, i)@ is
-- element @idx@ of @f@ evaluated with @i@ bound.
-- NOTE(review): @f@ is re-evaluated once per element of the result, and @sh@
-- is presumably expected to equal the shape of every @f i@; this is not checked.
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
  EBuild ext (SS n) (EPair ext sh size) $
    let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
    in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) (EFst ext arg)

-- | Component-wise equality of two index tuples; always true at rank 0.
-- For rank >= 2 both tuples are let-bound, so each is evaluated only once.
eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eidxEq SZ _ _ = EConst ext STBool True
eidxEq (SS SZ) a b =
  EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
eidxEq (SS n) a b
  | let ty = tTup (sreplicate (SS n) tIx)
  = ELet ext a $
      ELet ext (weakenExpr WSink b) $
        EOp ext OAnd $ EPair ext
          (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
                                          (ESnd ext (EVar ext ty IZ))))
          (eidxEq n (EFst ext (EVar ext ty (IS IZ))) (EFst ext (EVar ext ty IZ)))

-- | 'EMap' whose body gets 'KnownTy' evidence for the element type.
emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
emap f arr
  | STArr _ t <- typeOf arr
  , Dict <- styKnown t
  = EMap ext f arr

-- | Zip two arrays and map over the pairs. In the body, the element of the
-- second array is the innermost variable ('IZ') and the element of the first
-- array is 'IS IZ'; both are rewritten into projections of the zipped pair.
ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
ezipWith f arr1 arr2
  | STArr _ t1 <- typeOf arr1
  , STArr _ t2 <- typeOf arr2
  , Dict <- styKnown t1
  , Dict <- styKnown t2
  = EMap ext
      (subst
         (\_ t -> \case
             IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
             IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
             IS (IS i) -> EVar ext t (IS i))
         f)
      (EZip ext arr1 arr2)

-- | Annotation-free 'EZip'.
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
ezip = EZip ext

-- | If-then-else, via 'OIf' (True is Left), so @b@ is the then-branch and @c@
-- the else-branch.
eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)

-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape
-- (because it is _not_ empty).
--
-- At rank 1 the single component is compared against 0 directly. At higher
-- ranks the shape tuple is let-bound once and the per-component tests are
-- combined with 'OAnd'.
eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eshapeEmpty sn shapeE = case sn of
  SZ -> EConst ext STBool False
  SS SZ -> isZero (ESnd ext shapeE)
  SS rest ->
    let shapeTy = tTup (sreplicate sn tIx)
    in ELet ext shapeE $
         EOp ext OAnd $
           EPair ext (isZero (ESnd ext (EVar ext shapeTy IZ)))
                     (eshapeEmpty rest (EFst ext (EVar ext shapeTy IZ)))
  where
    -- Compare an index expression with the constant 0.
    isZero :: Ex env' TIx -> Ex env' (TScal TBool)
    isZero ix = EOp ext (OEq STI64) (EPair ext ix (EConst ext STI64 0))

-- | Turn a statically known 'Shape' into a constant shape-tuple expression.
eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
eshapeConst = \case
  ShNil -> ENil ext
  ShCons rest len -> EPair ext (eshapeConst rest) (EConst ext STI64 (fromIntegral @Int @Int64 len))

-- | Product of all components of a shape tuple; the constant 1 at rank 0.
eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
eshapeProd sn shapeE = case sn of
  SZ -> EConst ext STI64 1
  SS SZ -> ESnd ext shapeE
  SS rest ->
    eunPair shapeE (\_ restShape comp ->
      EOp ext (OMul STI64) (EPair ext (eshapeProd rest restShape) comp))

-- | Reshape an array of any rank into a rank-1 array whose length is the
-- product of the original shape's components.
eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
eflatten e =
  let STArr rank _ = typeOf e
  in elet e $
       EReshape ext (SS SZ)
         (EPair ext (ENil ext) (eshapeProd rank (EShape ext (evar IZ))))
         (evar IZ)

-- NOTE(review): the helpers below are commented out. As written they would not
-- type-check against the current 'EAccum' / 'EOneHot' constructors, which take
-- an 'SAcPrj' where these pass @(ENil ext)@.
-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
-- ezeroD2 t ezi = EZero ext (d2M t) ezi

-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea

-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev

-- | Deconstruct a pair expression for a continuation that receives both
-- components together with a weakening from @env@ into the environment they
-- live in. This happens in one of three ways:
--
-- * a syntactic 'EPair' is split directly;
-- * a cheap expression (see 'cheapExpr') is projected twice;
-- * anything else is let-bound once and its variable is projected.
eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
eunPair pairE cont = case pairE of
  EPair _ lhs rhs -> cont WId lhs rhs
  _ | cheapExpr pairE -> cont WId (EFst ext pairE) (ESnd ext pairE)
    | otherwise -> elet pairE (cont WSink (EFst ext (evar IZ)) (ESnd ext (evar IZ)))

-- | First projection. A syntactic 'EPair' is reduced to its first component,
-- so the second component is not evaluated.
-- NOTE(review): if that component can carry effects (e.g. 'EAccum'), they are
-- dropped; confirm callers never rely on them.
efst :: Ex env (TPair a b) -> Ex env a
efst = \case
  EPair _ lhs _ -> lhs
  other -> EFst ext other

-- | Second projection. Like 'efst', it reduces a syntactic 'EPair' and drops
-- the other component.
esnd :: Ex env (TPair a b) -> Ex env b
esnd = \case
  EPair _ _ rhs -> rhs
  other -> ESnd ext other

-- | Let-binding that inlines the right-hand side when it is cheap (see
-- 'cheapExpr') and otherwise emits an 'ELet'. The body gets 'KnownTy' evidence
-- for the bound type.
elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
elet rhs body = case styKnown (typeOf rhs) of
  Dict | cheapExpr rhs -> substInline rhs body
       | otherwise -> ELet ext rhs body

-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
use :: Ex env a -> Ex env b -> Ex env b
use effect result = elet effect (weakenExpr WSink result)

-- | 'EMaybe' with the scrutinee first. The Just branch gets 'KnownTy'
-- evidence for the payload type.
emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
emaybe scrut onNothing onJust = case typeOf scrut of
  STMaybe payloadTy -> case styKnown payloadTy of
    Dict -> EMaybe ext onNothing onJust scrut

-- | 'ECase' whose branches get 'KnownTy' evidence for both alternative types.
ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
ecase scrut onLeft onRight = case typeOf scrut of
  STEither tl tr -> case (styKnown tl, styKnown tr) of
    (Dict, Dict) -> ECase ext scrut onLeft onRight

-- | 'ELCase' whose branches get 'KnownTy' evidence for both alternative types.
elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
elcase scrut onNil onLeft onRight = case typeOf scrut of
  STLEither tl tr -> case (styKnown tl, styKnown tr) of
    (Dict, Dict) -> ELCase ext scrut onNil onLeft onRight

-- | A variable whose type singleton comes from the 'KnownTy' instance.
evar :: KnownTy a => Idx env a -> Ex env a
evar = EVar ext knownTy

-- | Build the 'ZeroInfo' for a monoid type from a reference value of that
-- type. The reference is let-bound once. Pairs are traversed structurally,
-- arrays are mapped element-wise, and every other case yields 'ENil'.
makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
makeZeroInfo ty reference = ELet ext reference (zeroInfoOf ty (EVar ext (fromSMTy ty) IZ))
  where
    -- The expression argument may be used more than once (once per pair
    -- component), so it must be cheap to duplicate. Every call passes a
    -- variable or a projection of one.
    zeroInfoOf :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
    zeroInfoOf mty e = case mty of
      SMTNil -> ENil ext
      SMTPair ta tb -> EPair ext (zeroInfoOf ta (EFst ext e)) (zeroInfoOf tb (ESnd ext e))
      SMTLEither{} -> ENil ext
      SMTMaybe{} -> ENil ext
      SMTArr _ telem -> emap (zeroInfoOf telem (EVar ext (fromSMTy telem) IZ)) e
      SMTScal{} -> ENil ext
-- | Split a sparsity on a pair type into sparsities for the two components.
--
-- The continuation receives the two component sparsities and either:
--
-- * a proof that the sparse type @t'@ already is the pair of the component
--   sparse types (the 'SpPair' case); or
-- * a projection. Given a variable of type @t'@, it passes its own
--   continuation three things: a wrapper that introduces any needed
--   let-binding, and the two component expressions, which are valid inside
--   that wrapper.
splitSparsePair
  :: -- given a sparsity
     STy (TPair a b) -> Sparse (TPair a b) t'
  -> (forall a' b'.
        -- I give you back two sparsities for a and b
        Sparse a a' -> Sparse b b'
        -- furthermore, I tell you that either your t' is already this (a', b') pair...
     -> Either
          (t' :~: TPair a' b')
          -- or I tell you how to construct a' and b' from t', given an actual t'
          (forall r' env.
             Idx env t'
          -> (forall env'.
                (forall c. Ex env' c -> Ex env c)
             -> Ex env' a' -> Ex env' b' -> r')
          -> r')
     -> r)
  -> r
-- Whole pair absent: both components are absent and project to 'ENil'; no
-- binding is needed, so the wrapper is 'id'.
splitSparsePair _ SpAbsent k =
  k SpAbsent SpAbsent $ Right $ \_ k2 ->
    k2 id (ENil ext) (ENil ext)
-- Already a pair of sparse components: nothing to construct.
splitSparsePair _ (SpPair s1 s2) k1 = k1 s1 s2 $ Left Refl
-- One Maybe layer around a sparse pair. Each component becomes Maybe-wrapped.
-- The wrapper let-binds a pair of Maybes: (Nothing, Nothing) when the whole
-- value is Nothing, and (Just (fst p), Just (snd p)) for Just p. The components
-- are then projections of that bound pair (of type t').
splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
  let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
  k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
    k2 (elet $
          emaybe (EVar ext (STMaybe (applySparse s t)) i)
            (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
            (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
       (EFst ext (EVar ext t' IZ))
       (ESnd ext (EVar ext t' IZ))
-- A Maybe layer around an absent pair carries no information: same as SpAbsent.
splitSparsePair _ (SpSparse SpAbsent) k =
  k SpAbsent SpAbsent $ Right $ \_ k2 ->
    k2 id (ENil ext) (ENil ext)
-- -- TODO: having to handle sparse-of-sparse at all is ridiculous
-- Two Maybe layers: recurse on the single-layer sparsity. The projection first
-- let-binds the flattened value (Nothing stays Nothing, Just m becomes m), then
-- hands it to the inner projection via index IZ. The composed wrapper applies
-- the inner wrapper inside that let. The inner call is on an 'SpSparse'
-- sparsity, whose type is a Maybe and so can never equal a pair; the Left case
-- is refuted with an empty case.
splitSparsePair t (SpSparse (SpSparse s)) k =
  splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
    k s1 s2 $ Right $ \i k2 ->
      case eres of
        Left refl -> case refl of {}
        Right f ->
          f IZ $ \wrap e1 e2 ->
            k2 (\body ->
                  elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
                          (ENothing ext (applySparse s t))
                          (evar IZ)) $
                    wrap body)
               e1 e2