{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
-- | The typed, well-scoped expression language ('Expr' / 'Ex'), its primitive
-- operations ('SOp'), generic traversals (type reconstruction, annotation
-- mapping, substitution, weakening), the @Known*@ singleton classes, and a set
-- of smart constructors / combinators for building 'Ex' terms.
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where

import Data.Functor.Const
import Data.Functor.Identity
import Data.Kind (Type)

import Array
import AST.Accum
import AST.Sparse.Types
import AST.Types
import AST.Weaken
import CHAD.Types
import Data


-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
--
-- The monoid operations are not supposed to appear in the input to CHAD, and
-- are intended to be eliminated after simplification, so that neither the
-- input program nor the output program contains these constructors.
-- TODO: ensure this by a "stage" type parameter.
--
-- Expressions are indexed as @Expr x env t@:
--   * @x@   is an annotation stored as the first field of every constructor
--           (read it with 'extOf', change it with 'mapExt' / 'travExt');
--   * @env@ is the typing environment: variables are typed de Bruijn indices
--           ('Idx' @env@ @t@), and every binder pushes its bound type onto
--           the front of the list (e.g. the body of 'ELet' lives in @a : env@);
--   * @t@   is the type of the expression (recomputable with 'typeOf').
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus
  -- | Variable reference; the 'STy' is its type, returned as-is by 'typeOf'.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  -- | @ELet x rhs body@: @body@ sees @rhs@ as its innermost variable ('IZ').
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  -- | The 'STy' is the type of the (absent) right alternative.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  -- | The 'STy' is the type of the (absent) left alternative.
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  -- | @ECase x scrutinee leftBranch rightBranch@; each branch binds the
  -- payload as its innermost variable.
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
  ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
  EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
  -- | @EMaybe x nothingBranch justBranch scrutinee@ (note: scrutinee last);
  -- the just-branch binds the payload as its innermost variable.
  EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b

  -- array operations
  -- | Array literal of scalars.
  EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
  -- | @EBuild x n shape body@: @body@ binds the index tuple as its innermost
  -- variable.
  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
  -- | @EFold1Inner x comm combine init array@ drops one dimension; @combine@
  -- binds two variables of the element type.
  EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
  ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  -- | Wrap a value as a zero-dimensional array.
  EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
  -- | Adds one dimension; the 'TIx' argument is presumably the size of the
  -- new inner dimension (TODO confirm).
  EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
  EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))

  -- expression operations
  -- | Scalar literal.
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  -- | Index an array with a full index tuple.
  EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
  -- | Apply a primitive operation ('SOp') to its argument.
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- custom derivatives
  -- 'b' is the part of the input of the operation that derivatives should
  -- be backpropagated to; 'a' is the inactive part. The dual field of
  -- ECustom does not allow a derivative to be generated for 'a', and hence
  -- none is propagated.
  -- The three operation fields are closed terms (their environments are fixed
  -- lists, not @env@); only the last two arguments live in @env@.
  ECustom :: x t -> STy a -> STy b -> STy tape
          -> Expr x [b, a] t  -- ^ regular operation
          -> Expr x [D1 b, D1 a] (TPair (D1 t) tape)  -- ^ CHAD forward pass
          -> Expr x [D2 t, tape] (D2 b)  -- ^ CHAD reverse derivative
          -> Expr x env a -> Expr x env b
          -> Expr x env t

  -- fake halfway checkpointing
  ERecompute :: x t -> Expr x env t -> Expr x env t

  -- accumulation effect on monoids
  -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
  -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
  -- need to create any zeros.
  -- (@EWith x mty initialiser body@: @body@ binds the accumulator as its
  -- innermost variable; the result pairs @body@'s value with the monoid value.)
  EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
  -- The 'Sparse' here is eliminated to dense by UnMonoid.
  -- (@EAccum x mty projection index sparsity value accumulator@.)
  EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil

  -- monoidal operations (to be desugared to regular operations after simplification)
  EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
  EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
  EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
  EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t

  -- interface of abstract monoidal types
  ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
  ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
  ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
  -- | @ELCase x scrutinee nilBranch leftBranch rightBranch@; the nil branch
  -- binds nothing, the other two bind the payload as innermost variable.
  ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- partiality
  -- | Failure carrying the given message; the 'STy' is the expression's type.
  EError :: x a -> STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions whose per-node annotation carries no information.
type Ex = Expr (Const ()) 

-- | The (only) annotation value for 'Ex' nodes.
ext :: Const () a
ext = Const ()

-- | Flag on 'EFold1Inner' stating whether its combining function is
-- commutative.
data Commutative = Commut | Noncommut
  deriving (Show)

-- | Primitive operations, indexed by argument type @a@ and result type @t@
-- (recoverable as singletons with 'opt1' and 'opt2'). Binary operations take
-- a pair.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OIf :: SOp (TScal TBool) (TEither TNil TNil)  -- True is Left, False is Right
  ORound64 :: SOp (TScal TF64) (TScal TI64)
  OToFl64 :: SOp (TScal TI64) (TScal TF64)
  ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)

-- | Argument type of a primitive operation.
opt1 :: SOp a t -> STy a
opt1 = \case
  OAdd t -> STPair (STScal t) (STScal t)
  OMul t -> STPair (STScal t) (STScal t)
  ONeg t -> STScal t
  OLt t -> STPair (STScal t) (STScal t)
  OLe t -> STPair (STScal t) (STScal t)
  OEq t -> STPair (STScal t) (STScal t)
  ONot -> STScal STBool
  OAnd -> STPair (STScal STBool) (STScal STBool)
  OOr -> STPair (STScal STBool) (STScal STBool)
  OIf -> STScal STBool
  ORound64 -> STScal STF64
  OToFl64 -> STScal STI64
  ORecip t -> STScal t
  OExp t -> STScal t
  OLog t -> STScal t
  OIDiv t -> STPair (STScal t) (STScal t)
  OMod t -> STPair (STScal t) (STScal t)

-- | Result type of a primitive operation.
opt2 :: SOp a t -> STy t
opt2 = \case
  OAdd t -> STScal t
  OMul t -> STScal t
  ONeg t -> STScal t
  OLt _ -> STScal STBool
  OLe _ -> STScal STBool
  OEq _ -> STScal STBool
  ONot -> STScal STBool
  OAnd -> STScal STBool
  OOr -> STScal STBool
  OIf -> STEither STNil STNil
  ORound64 -> STScal STI64
  OToFl64 -> STScal STF64
  ORecip t -> STScal t
  OExp t -> STScal t
  OLog t -> STScal t
  OIDiv t -> STScal t
  OMod t -> STScal t

-- | Reconstruct the type of an expression from the type singletons stored in
-- its nodes, recursing into subexpressions only where a node does not carry
-- enough information itself. The pattern guards on 'STArr' / 'STPair' cannot
-- fail: the result type of the subexpression fixes the singleton's shape.
typeOf :: Expr x env t -> STy t
typeOf = \case
  EVar _ t _ -> t
  ELet _ _ e -> typeOf e

  EPair _ a b -> STPair (typeOf a) (typeOf b)
  EFst _ e | STPair t _ <- typeOf e -> t
  ESnd _ e | STPair _ t <- typeOf e -> t
  ENil _ -> STNil
  EInl _ t2 e -> STEither (typeOf e) t2
  EInr _ t1 e -> STEither t1 (typeOf e)
  ECase _ _ a _ -> typeOf a
  ENothing _ t -> STMaybe t
  EJust _ e -> STMaybe (typeOf e)
  EMaybe _ e _ _ -> typeOf e
  ELNil _ t1 t2 -> STLEither t1 t2
  ELInl _ t2 e -> STLEither (typeOf e) t2
  ELInr _ t1 e -> STLEither t1 (typeOf e)
  ELCase _ _ a _ _ -> typeOf a

  EConstArr _ n t _ -> STArr n (STScal t)
  EBuild _ n _ e -> STArr n (typeOf e)
  EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
  ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EUnit _ e -> STArr SZ (typeOf e)
  EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
  EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t

  EConst _ t _ -> STScal t
  EIdx0 _ e | STArr _ t <- typeOf e -> t
  EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
  EIdx _ e _ | STArr _ t <- typeOf e -> t
  EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
  EOp _ op _ -> opt2 op

  -- the type of the regular-operation field is the type of the whole node
  ECustom _ _ _ _ e _ _ _ _ -> typeOf e
  ERecompute _ e -> typeOf e

  -- EWith's fields are (initialiser :: t, body :: a); the result is (a, t)
  EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
  EAccum _ _ _ _ _ _ _ -> STNil

  EZero _ t _ -> fromSMTy t
  EDeepZero _ t _ -> fromSMTy t
  EPlus _ t _ _ -> fromSMTy t
  EOneHot _ t _ _ _ -> fromSMTy t

  EError _ t _ -> t

-- | The annotation stored at the root node.
extOf :: Expr x env t -> x t
extOf = \case
  EVar x _ _ -> x
  ELet x _ _ -> x
  EPair x _ _ -> x
  EFst x _ -> x
  ESnd x _ -> x
  ENil x -> x
  EInl x _ _ -> x
  EInr x _ _ -> x
  ECase x _ _ _ -> x
  ENothing x _ -> x
  EJust x _ -> x
  EMaybe x _ _ _ -> x
  ELNil x _ _ -> x
  ELInl x _ _ -> x
  ELInr x _ _ -> x
  ELCase x _ _ _ _ -> x
  EConstArr x _ _ _ -> x
  EBuild x _ _ _ -> x
  EFold1Inner x _ _ _ _ -> x
  ESum1Inner x _ -> x
  EUnit x _ -> x
  EReplicate1Inner x _ _ -> x
  EMaximum1Inner x _ -> x
  EMinimum1Inner x _ -> x
  EConst x _ _ -> x
  EIdx0 x _ -> x
  EIdx1 x _ _ -> x
  EIdx x _ _ -> x
  EShape x _ -> x
  EOp x _ _ -> x
  ECustom x _ _ _ _ _ _ _ _ -> x
  ERecompute x _ -> x
  EWith x _ _ _ -> x
  EAccum x _ _ _ _ _ _ -> x
  EZero x _ _ -> x
  EDeepZero x _ _ -> x
  EPlus x _ _ _ -> x
  EOneHot x _ _ _ _ -> x
  EError x _ _ -> x

-- | Replace the annotation at every node (pure specialisation of 'travExt').
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
mapExt f = runIdentity . travExt (Identity . f)

{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
-- | Effectfully replace the annotation at every node, including inside the
-- closed subterms of 'ECustom'. Effects are sequenced per node as: the node's
-- own annotation first, then its subexpressions left to right in field order.
travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
travExt f = \case
  EVar x t i -> EVar <$> f x <*> pure t <*> pure i
  ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
  EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
  EFst x e -> EFst <$> f x <*> travExt f e
  ESnd x e -> ESnd <$> f x <*> travExt f e
  ENil x -> ENil <$> f x
  EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
  EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
  ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
  ENothing x t -> ENothing <$> f x <*> pure t
  EJust x e -> EJust <$> f x <*> travExt f e
  EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
  ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
  ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
  ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
  ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
  EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
  EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
  EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
  ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
  EUnit x e -> EUnit <$> f x <*> travExt f e
  EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
  EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
  EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
  EConst x t v -> EConst <$> f x <*> pure t <*> pure v
  EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
  EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
  EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
  EShape x e -> EShape <$> f x <*> travExt f e
  EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
  -- a, b, c are ECustom's closed operation terms; e1, e2 its arguments
  ECustom x s t p a b c e1 e2 ->
    ECustom <$> f x <*> pure s <*> pure t <*> pure p
            <*> travExt f a <*> travExt f b <*> travExt f c
            <*> travExt f e1 <*> travExt f e2
  ERecompute x e -> ERecompute <$> f x <*> travExt f e
  EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
  EAccum x t p e1 sp e2 e3 ->
    EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp
           <*> travExt f e2 <*> travExt f e3
  EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
  EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
  EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
  EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
  EError x t s -> EError <$> f x <*> pure t <*> pure s

-- | Substitute @repl@ for the innermost variable and shift all other variables
-- down by one. @repl@ is inserted verbatim at every occurrence, so it is
-- duplicated if the variable is used more than once.
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl = subst $ \x t -> \case
  IZ -> repl
  IS i -> EVar x t i

-- | Replace the innermost variable (of type @a@) with @repl@, which lives in
-- an environment whose innermost variable has type @b@ instead; all other
-- variables keep their index. Replaced 'EVar' nodes are rebuilt with 'ext'.
subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
subst0 repl = subst $ \_ t -> \case
  IZ -> repl
  IS i -> EVar ext t (IS i)

-- | Substitute every free variable using the given mapping. The mapping's
-- result is weakened as needed when the variable occurs under binders.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) -> Expr x env t -> Expr x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId

-- | Worker for 'subst' and 'weakenExpr'. @w@ relates the mapping's target
-- environment @env'@ to the environment at the current position; every time
-- a binder is entered, @w@ is extended with 'WCopy' and the mapping with
-- 'sinkF', so the newly bound variable maps to itself.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) -> env' :> envOut -> Expr x env t -> Expr x envOut t
subst' f w = \case
  EVar x t i -> f x t w i
  ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
  EPair x a b -> EPair x (subst' f w a) (subst' f w b)
  EFst x e -> EFst x (subst' f w e)
  ESnd x e -> ESnd x (subst' f w e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (subst' f w e)
  EInr x t e -> EInr x t (subst' f w e)
  ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
  ENothing x t -> ENothing x t
  EJust x e -> EJust x (subst' f w e)
  EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
  ELNil x t1 t2 -> ELNil x t1 t2
  ELInl x t e -> ELInl x t (subst' f w e)
  ELInr x t e -> ELInr x t (subst' f w e)
  ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
  EConstArr x n t a -> EConstArr x n t a
  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  -- the combining function binds two variables, hence two levels of sinking
  EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
  ESum1Inner x e -> ESum1Inner x (subst' f w e)
  EUnit x e -> EUnit x (subst' f w e)
  EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
  EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
  EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
  EConst x t v -> EConst x t v
  EIdx0 x e -> EIdx0 x (subst' f w e)
  EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
  EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
  EShape x e -> EShape x (subst' f w e)
  EOp x op e -> EOp x op (subst' f w e)
  -- a, b, c are closed (fixed environments), so only the arguments change
  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
  ERecompute x e -> ERecompute x (subst' f w e)
  EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
  EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
  EZero x t e -> EZero x t (subst' f w e)
  EDeepZero x t e -> EDeepZero x t (subst' f w e)
  EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
  EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
  EError x t s -> EError x t s
  where
    -- Extend the mapping under one binder: the bound variable maps to itself
    -- (through the weakening), other variables go through the old mapping.
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

-- | Rename variables along a weakening; the term structure is unchanged.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))

-- | Scalar types whose singleton can be produced implicitly.
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
instance KnownScalTy TF32 where knownScalTy = STF32
instance KnownScalTy TF64 where knownScalTy = STF64
instance KnownScalTy TBool where knownScalTy = STBool

-- | Types whose 'STy' singleton can be produced implicitly.
class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy

-- | Monoid types whose 'SMTy' singleton can be produced implicitly. There is
-- no instance for 'TEither'; 'TScal' requires a numeric scalar type.
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy

-- | Environments whose list of type singletons can be produced implicitly.
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv

-- | Recover a 'KnownTy' dictionary from an explicit singleton.
styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict

-- | Recover a 'KnownMTy' dictionary from an explicit singleton.
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict

-- | Recover a 'KnownScalTy' dictionary from an explicit singleton.
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
sscaltyKnown STI64 = Dict
sscaltyKnown STF32 = Dict
sscaltyKnown STF64 = Dict
sscaltyKnown STBool = Dict

-- | Recover a 'KnownEnv' dictionary from an explicit list of singletons.
envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict

-- | Build a tuple expression (nested 'EPair's ending in 'ENil') from a list of
-- component expressions.
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)

-- | @ebuildUp1 n sh size f@ builds an (n+1)-dimensional array of shape
-- @sh@ extended with inner dimension @size@. Element @(ix, i)@ is element
-- @ix@ of the n-dimensional array that @f@ produces for inner index @i@.
-- NOTE(review): @f@ is placed inside the 'EBuild' body, so as written it is
-- recomputed for every output element before being indexed; presumably later
-- simplification is relied upon to make this efficient — confirm.
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
  EBuild ext (SS n) (EPair ext sh size) $
    -- arg is the full (n+1)-index; its second component is the inner index
    let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
    in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
                (EFst ext arg)

-- | Component-wise equality of two index tuples of rank @n@ (True for rank 0).
eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eidxEq SZ _ _ = EConst ext STBool True
eidxEq (SS SZ) a b =
  EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
eidxEq (SS n) a b
  | let ty = tTup (sreplicate (SS n) tIx)
  -- both tuples are let-bound (env: b : a : env) because each is used twice
  = ELet ext a $
    ELet ext (weakenExpr WSink b) $
      EOp ext OAnd $ EPair ext
        (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
                                        (ESnd ext (EVar ext ty IZ))))
        (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
                  (EFst ext (EVar ext ty IZ)))

-- | Elementwise map: @f@ receives the current element as its innermost
-- variable. The input array is let-bound once and its shape is reused.
emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
emap f arr
  | STArr n t <- typeOf arr
  , Dict <- styKnown t
  = ELet ext arr $
      EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
        -- env here: index : arr : env
        ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
          -- env here: element : index : arr : env; f skips index and arr
          weakenExpr (WCopy (WSink .> WSink)) f

-- | Elementwise zip-with: @f@ receives the element of @arr2@ as its innermost
-- variable and the element of @arr1@ as the next one. The result has the
-- shape of @arr1@; the shape of @arr2@ is not checked here.
ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
ezipWith f arr1 arr2
  | STArr n t1 <- typeOf arr1
  , STArr _ t2 <- typeOf arr2
  , Dict <- styKnown t1
  , Dict <- styKnown t2
  = ELet ext arr1 $
    ELet ext (weakenExpr WSink arr2) $
      EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
        -- env here: index : arr2 : arr1 : env
        ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
        -- env here: elem1 : index : arr2 : arr1 : env
        ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
          -- env here: elem2 : elem1 : index : arr2 : arr1 : env; f skips the last three
          weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f

-- | Pair up two arrays elementwise (shape taken from the first, as in
-- 'ezipWith').
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
ezip arr1 arr2 =
  let STArr _ t1 = typeOf arr1
      STArr _ t2 = typeOf arr2
  in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2

-- | If-then-else, encoded as a case on 'OIf' (True is Left, False is Right).
eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)

-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
-- NOTE(review): components are combined with 'OAnd', so e.g. the shape
-- (0, 5) yields False even though such an array has no elements; confirm that
-- "all-zero" rather than "any-zero" is what callers need.
-- The rank-1 case must be matched before the general case: with 'OAnd',
-- recursing down to the rank-0 base case (False) would make every result
-- False.
eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eshapeEmpty SZ _ = EConst ext STBool False
eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
eshapeEmpty (SS n) e =
  ELet ext e $
    EOp ext OAnd (EPair ext
      (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
                                      (EConst ext STI64 0)))
      (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))

-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
-- ezeroD2 t ezi = EZero ext (d2M t) ezi

-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea

-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev

-- | Deconstruct a pair expression in continuation-passing style. A literal
-- 'EPair' is taken apart directly (the continuation gets the identity
-- weakening and the components as-is, so they are duplicated if @k@ uses them
-- more than once); any other expression is let-bound once and projected.
eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
eunPair (EPair _ e1 e2) k = k WId e1 e2
eunPair e k =
  elet e $
    k WSink
      (EFst ext (evar IZ))
      (ESnd ext (evar IZ))

-- | First projection that reduces on a literal pair (the second component is
-- then discarded).
efst :: Ex env (TPair a b) -> Ex env a
efst (EPair _ e1 _) = e1
efst e = EFst ext e

-- | Second projection that reduces on a literal pair (the first component is
-- then discarded).
esnd :: Ex env (TPair a b) -> Ex env b
esnd (EPair _ _ e2) = e2
esnd e = ESnd ext e

-- | 'ELet' that makes the bound type available as 'KnownTy' in the body.
elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
elet rhs body
  | Dict <- styKnown (typeOf rhs)
  = ELet ext rhs body

-- | 'EMaybe' with the scrutinee first and 'KnownTy' for the payload in the
-- just-branch.
emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
emaybe e a b
  | STMaybe t <- typeOf e
  , Dict <- styKnown t
  = EMaybe ext a b e

-- | 'ELCase' that makes the payload types available as 'KnownTy' in the
-- left and right branches.
elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c
elcase e a b c
  | STLEither t1 t2 <- typeOf e
  , Dict <- styKnown t1
  , Dict <- styKnown t2
  = ELCase ext e a b c

-- | Variable reference whose type singleton comes from 'KnownTy'.
evar :: KnownTy a => Idx env a -> Ex env a
evar = EVar ext knownTy

-- | Build the 'ZeroInfo' for monoid type @ty@ from a reference value of that
-- type. The reference is let-bound once, so 'go' may use the variable freely;
-- for arrays the per-element zero info is computed with 'emap'. The reference
-- is bound even when the result does not use it (e.g. 'SMTNil', 'SMTScal'),
-- presumably leaving its removal to the simplifier.
makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
  where
    -- invariant: expression argument is duplicable
    go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
    go SMTNil _ = ENil ext
    go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
    go SMTLEither{} _ = ENil ext
    go SMTMaybe{} _ = ENil ext
    go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
    go SMTScal{} _ = ENil ext