aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST.hs')
-rw-r--r--src/CHAD/AST.hs705
1 files changed, 705 insertions, 0 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
new file mode 100644
index 0000000..aa6aa96
--- /dev/null
+++ b/src/CHAD/AST.hs
@@ -0,0 +1,705 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where
+
+import Data.Functor.Const
+import Data.Functor.Identity
+import Data.Int (Int64)
+import Data.Kind (Type)
+
+import CHAD.Array
+import CHAD.AST.Accum
+import CHAD.AST.Sparse.Types
+import CHAD.AST.Types
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+-- General assumption: head of the list (whatever way it is associated) is the
+-- inner variable / inner array dimension. In pretty printing, the inner
+-- variable / inner dimension is printed on the _right_.
+--
+-- All the monoid operations are unsupposed as the input to CHAD, and are
+-- intended to be eliminated after simplification, so that the input program as
+-- well as the output program do not contain these constructors.
+-- TODO: ensure this by a "stage" type parameter.
+type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr x env t where
+ -- lambda calculus
+ EVar :: x t -> STy t -> Idx env t -> Expr x env t
+ ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
+
+ -- base types
+ EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
+ EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
+ ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
+ ENil :: x TNil -> Expr x env TNil
+ EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
+ EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
+ ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+ ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
+ EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
+ EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
+
+ -- array operations
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
+ EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
+ -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
+ EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
+
+ -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
+ -- an implementation is allowed to parallelise this thing and store the b
+ -- values in some implementation-defined order.
+ -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
+ -> Expr x (TPair t1 t1 : env) (TPair t1 b)
+ -> Expr x env t1
+ -> Expr x env (TArr (S n) t1)
+ -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
+ -- Reverse derivative of EFold1Inner. The contributions to the initial
+ -- element are not yet added together here; we assume a later fusion system
+ -- does that for us.
+ EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
+ -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
+ -> Expr x env (TArr n t2) -- incoming cotangent
+ -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
+
+ -- expression operations
+ EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
+ EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
+ EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
+ EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
+ EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
+ EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+
+ -- custom derivatives
+ -- 'b' is the part of the input of the operation that derivatives should
+ -- be backpropagated to; 'a' is the inactive part. The dual field of
+ -- ECustom does not allow a derivative to be generated for 'a', and hence
+ -- none is propagated.
+ -- No accumulators are allowed inside a, b and tape. This restriction is
+ -- currently not used very much, so could be relaxed in the future; be sure
+ -- to check this requirement whenever it is necessary for soundness!
+ ECustom :: x t -> STy a -> STy b -> STy tape
+ -> Expr x [b, a] t -- ^ regular operation
+ -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x env a -> Expr x env b
+ -> Expr x env t
+
+ -- fake halfway checkpointing
+ ERecompute :: x t -> Expr x env t -> Expr x env t
+
+ -- accumulation effect on monoids
+ -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
+ -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
+ -- need to create any zeros.
+ EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ -- The 'Sparse' here is eliminated to dense by UnMonoid.
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
+
+ -- monoidal operations (to be desugared to regular operations after simplification)
+ EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
+ EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+
+ -- interface of abstract monoidal types
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+
+ -- partiality
+ EError :: x a -> STy a -> String -> Expr x env a
+deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+
+type Ex = Expr (Const ())
+
+ext :: Const () a
+ext = Const ()
+
+data Commutative = Commut | Noncommut
+ deriving (Show)
+
+type SOp :: Ty -> Ty -> Type
+data SOp a t where
+ OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ ONot :: SOp (TScal TBool) (TScal TBool)
+ OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
+ ORound64 :: SOp (TScal TF64) (TScal TI64)
+ OToFl64 :: SOp (TScal TI64) (TScal TF64)
+ ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+deriving instance Show (SOp a t)
+
+opt1 :: SOp a t -> STy a
+opt1 = \case
+ OAdd t -> STPair (STScal t) (STScal t)
+ OMul t -> STPair (STScal t) (STScal t)
+ ONeg t -> STScal t
+ OLt t -> STPair (STScal t) (STScal t)
+ OLe t -> STPair (STScal t) (STScal t)
+ OEq t -> STPair (STScal t) (STScal t)
+ ONot -> STScal STBool
+ OAnd -> STPair (STScal STBool) (STScal STBool)
+ OOr -> STPair (STScal STBool) (STScal STBool)
+ OIf -> STScal STBool
+ ORound64 -> STScal STF64
+ OToFl64 -> STScal STI64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
+ OIDiv t -> STPair (STScal t) (STScal t)
+ OMod t -> STPair (STScal t) (STScal t)
+
+opt2 :: SOp a t -> STy t
+opt2 = \case
+ OAdd t -> STScal t
+ OMul t -> STScal t
+ ONeg t -> STScal t
+ OLt _ -> STScal STBool
+ OLe _ -> STScal STBool
+ OEq _ -> STScal STBool
+ ONot -> STScal STBool
+ OAnd -> STScal STBool
+ OOr -> STScal STBool
+ OIf -> STEither STNil STNil
+ ORound64 -> STScal STI64
+ OToFl64 -> STScal STF64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
+ OIDiv t -> STScal t
+ OMod t -> STScal t
+
+typeOf :: Expr x env t -> STy t
+typeOf = \case
+ EVar _ t _ -> t
+ ELet _ _ e -> typeOf e
+
+ EPair _ a b -> STPair (typeOf a) (typeOf b)
+ EFst _ e | STPair t _ <- typeOf e -> t
+ ESnd _ e | STPair _ t <- typeOf e -> t
+ ENil _ -> STNil
+ EInl _ t2 e -> STEither (typeOf e) t2
+ EInr _ t1 e -> STEither t1 (typeOf e)
+ ECase _ _ a _ -> typeOf a
+ ENothing _ t -> STMaybe t
+ EJust _ e -> STMaybe (typeOf e)
+ EMaybe _ e _ _ -> typeOf e
+ ELNil _ t1 t2 -> STLEither t1 t2
+ ELInl _ t2 e -> STLEither (typeOf e) t2
+ ELInr _ t1 e -> STLEither t1 (typeOf e)
+ ELCase _ _ a _ _ -> typeOf a
+
+ EConstArr _ n t _ -> STArr n (STScal t)
+ EBuild _ n _ e -> STArr n (typeOf e)
+ EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
+ EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EUnit _ e -> STArr SZ (typeOf e)
+ EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
+ EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
+
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
+ EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
+
+ EConst _ t _ -> STScal t
+ EIdx0 _ e | STArr _ t <- typeOf e -> t
+ EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
+ EIdx _ e _ | STArr _ t <- typeOf e -> t
+ EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
+ EOp _ op _ -> opt2 op
+
+ ECustom _ _ _ _ e _ _ _ _ -> typeOf e
+ ERecompute _ e -> typeOf e
+
+ EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
+ EAccum _ _ _ _ _ _ _ -> STNil
+
+ EZero _ t _ -> fromSMTy t
+ EDeepZero _ t _ -> fromSMTy t
+ EPlus _ t _ _ -> fromSMTy t
+ EOneHot _ t _ _ _ -> fromSMTy t
+
+ EError _ t _ -> t
+
+extOf :: Expr x env t -> x t
+extOf = \case
+ EVar x _ _ -> x
+ ELet x _ _ -> x
+ EPair x _ _ -> x
+ EFst x _ -> x
+ ESnd x _ -> x
+ ENil x -> x
+ EInl x _ _ -> x
+ EInr x _ _ -> x
+ ECase x _ _ _ -> x
+ ENothing x _ -> x
+ EJust x _ -> x
+ EMaybe x _ _ _ -> x
+ ELNil x _ _ -> x
+ ELInl x _ _ -> x
+ ELInr x _ _ -> x
+ ELCase x _ _ _ _ -> x
+ EConstArr x _ _ _ -> x
+ EBuild x _ _ _ -> x
+ EMap x _ _ -> x
+ EFold1Inner x _ _ _ _ -> x
+ ESum1Inner x _ -> x
+ EUnit x _ -> x
+ EReplicate1Inner x _ _ -> x
+ EMaximum1Inner x _ -> x
+ EMinimum1Inner x _ -> x
+ EReshape x _ _ _ -> x
+ EZip x _ _ -> x
+ EFold1InnerD1 x _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ -> x
+ EConst x _ _ -> x
+ EIdx0 x _ -> x
+ EIdx1 x _ _ -> x
+ EIdx x _ _ -> x
+ EShape x _ -> x
+ EOp x _ _ -> x
+ ECustom x _ _ _ _ _ _ _ _ -> x
+ ERecompute x _ -> x
+ EWith x _ _ _ -> x
+ EAccum x _ _ _ _ _ _ -> x
+ EZero x _ _ -> x
+ EDeepZero x _ _ -> x
+ EPlus x _ _ _ -> x
+ EOneHot x _ _ _ _ -> x
+ EError x _ _ -> x
+
+mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt f = runIdentity . travExt (Identity . f)
+
+{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
+travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt f = \case
+ EVar x t i -> EVar <$> f x <*> pure t <*> pure i
+ ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
+ EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
+ EFst x e -> EFst <$> f x <*> travExt f e
+ ESnd x e -> ESnd <$> f x <*> travExt f e
+ ENil x -> ENil <$> f x
+ EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
+ EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
+ ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
+ ENothing x t -> ENothing <$> f x <*> pure t
+ EJust x e -> EJust <$> f x <*> travExt f e
+ EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
+ ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
+ ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
+ ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
+ ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
+ EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
+ EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
+ EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
+ EUnit x e -> EUnit <$> f x <*> travExt f e
+ EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
+ EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
+ EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
+ EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EConst x t v -> EConst <$> f x <*> pure t <*> pure v
+ EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
+ EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
+ EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
+ EShape x e -> EShape <$> f x <*> travExt f e
+ EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
+ ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ ERecompute x e -> ERecompute <$> f x <*> travExt f e
+ EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
+ EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
+ EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
+ EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
+ EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
+ EError x t s -> EError <$> f x <*> pure t <*> pure s
+
+substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+substInline repl =
+ subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
+subst0 repl =
+ subst $ \_ t -> \case IZ -> repl
+ IS i -> EVar ext t (IS i)
+
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
+ -> Expr x env t -> Expr x env' t
+subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
+
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+ -> env' :> envOut
+ -> Expr x env t
+ -> Expr x envOut t
+subst' f w = \case
+ EVar x t i -> f x t w i
+ ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
+ EPair x a b -> EPair x (subst' f w a) (subst' f w b)
+ EFst x e -> EFst x (subst' f w e)
+ ESnd x e -> ESnd x (subst' f w e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (subst' f w e)
+ EInr x t e -> EInr x t (subst' f w e)
+ ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (subst' f w e)
+ EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (subst' f w e)
+ ELInr x t e -> ELInr x t (subst' f w e)
+ ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
+ EConstArr x n t a -> EConstArr x n t a
+ EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
+ EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
+ ESum1Inner x e -> ESum1Inner x (subst' f w e)
+ EUnit x e -> EUnit x (subst' f w e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
+ EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
+ EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
+ EZip x a b -> EZip x (subst' f w a) (subst' f w b)
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (subst' f w e)
+ EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
+ EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
+ EShape x e -> EShape x (subst' f w e)
+ EOp x op e -> EOp x op (subst' f w e)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
+ ERecompute x e -> ERecompute x (subst' f w e)
+ EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
+ EZero x t e -> EZero x t (subst' f w e)
+ EDeepZero x t e -> EDeepZero x t (subst' f w e)
+ EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EError x t s -> EError x t s
+ where
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF f' x' t w' = \case
+ IZ -> EVar x' t (w' @> IZ)
+ IS i -> f' x' t (WPop w') i
+
+weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
+
+class KnownScalTy t where knownScalTy :: SScalTy t
+instance KnownScalTy TI32 where knownScalTy = STI32
+instance KnownScalTy TI64 where knownScalTy = STI64
+instance KnownScalTy TF32 where knownScalTy = STF32
+instance KnownScalTy TF64 where knownScalTy = STF64
+instance KnownScalTy TBool where knownScalTy = STBool
+
+class KnownTy t where knownTy :: STy t
+instance KnownTy TNil where knownTy = STNil
+instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
+instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
+instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
+instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+
+class KnownEnv env where knownEnv :: SList STy env
+instance KnownEnv '[] where knownEnv = SNil
+instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STMaybe t) | Dict <- styKnown t = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
+envKnown :: SList STy env -> Dict (KnownEnv env)
+envKnown SNil = Dict
+envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ EFst _ e -> cheapExpr e
+ ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
+ _ -> False
+
+eTup :: SList (Ex env) list -> Ex env (Tup list)
+eTup = mkTup (ENil ext) (EPair ext)
+
+ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
+ebuildUp1 n sh size f =
+ EBuild ext (SS n) (EPair ext sh size) $
+ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
+ in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
+ (EFst ext arg)
+
+eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eidxEq SZ _ _ = EConst ext STBool True
+eidxEq (SS SZ) a b =
+ EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
+eidxEq (SS n) a b
+ | let ty = tTup (sreplicate (SS n) tIx)
+ = ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EOp ext OAnd $ EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
+ (ESnd ext (EVar ext ty IZ))))
+ (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
+ (EFst ext (EVar ext ty IZ)))
+
+emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr
+ | STArr _ t <- typeOf arr
+ , Dict <- styKnown t
+ = EMap ext f arr
+
+ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
+ezipWith f arr1 arr2
+ | STArr _ t1 <- typeOf arr1
+ , STArr _ t2 <- typeOf arr2
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
+ IS (IS i) -> EVar ext t (IS i))
+ f)
+ (EZip ext arr1 arr2)
+
+ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
+ezip = EZip ext
+
+eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
+eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
+
+-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
+eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eshapeEmpty SZ _ = EConst ext STBool False
+eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
+eshapeEmpty (SS n) e =
+ ELet ext e $
+ EOp ext OAnd (EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
+ (EConst ext STI64 0)))
+ (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
+eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
+eshapeConst ShNil = ENil ext
+eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+
+eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
+eshapeProd SZ _ = EConst ext STI64 1
+eshapeProd (SS SZ) e = ESnd ext e
+eshapeProd (SS n) e =
+ eunPair e $ \_ e1 e2 ->
+ EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
+
+eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
+eflatten e =
+ let STArr n _ = typeOf e
+ in elet e $
+ EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
+
+-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
+-- ezeroD2 t ezi = EZero ext (d2M t) ezi
+
+-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
+-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
+
+-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
+-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
+
+eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
+eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e)
+eunPair e k =
+ elet e $
+ k WSink
+ (EFst ext (evar IZ))
+ (ESnd ext (evar IZ))
+
+efst :: Ex env (TPair a b) -> Ex env a
+efst (EPair _ e1 _) = e1
+efst e = EFst ext e
+
+esnd :: Ex env (TPair a b) -> Ex env b
+esnd (EPair _ _ e2) = e2
+esnd e = ESnd ext e
+
+elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
+elet rhs body
+ | Dict <- styKnown (typeOf rhs)
+ = if cheapExpr rhs
+ then substInline rhs body
+ else ELet ext rhs body
+
+-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
+emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
+emaybe e a b
+ | STMaybe t <- typeOf e
+ , Dict <- styKnown t
+ = EMaybe ext a b e
+
+ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+ecase e a b
+ | STEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ECase ext e a b
+
+elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+elcase e a b c
+ | STLEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELCase ext e a b c
+
+evar :: KnownTy a => Idx env a -> Ex env a
+evar = EVar ext knownTy
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext
+
+splitSparsePair
+ :: -- given a sparsity
+ STy (TPair a b) -> Sparse (TPair a b) t'
+ -> (forall a' b'.
+ -- I give you back two sparsities for a and b
+ Sparse a a' -> Sparse b b'
+ -- furthermore, I tell you that either your t' is already this (a', b') pair...
+ -> Either
+ (t' :~: TPair a' b')
+ -- or I tell you how to construct a' and b' from t', given an actual t'
+ (forall r' env.
+ Idx env t'
+ -> (forall env'.
+ (forall c. Ex env' c -> Ex env c)
+ -> Ex env' a' -> Ex env' b' -> r')
+ -> r')
+ -> r)
+ -> r
+splitSparsePair _ SpAbsent k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+splitSparsePair _ (SpPair s1 s2) k1 =
+ k1 s1 s2 $ Left Refl
+splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
+ let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
+ k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
+ k2 (elet $
+ emaybe (EVar ext (STMaybe (applySparse s t)) i)
+ (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
+ (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
+ (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ))
+
+splitSparsePair _ (SpSparse SpAbsent) k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+-- -- TODO: having to handle sparse-of-sparse at all is ridiculous
+splitSparsePair t (SpSparse (SpSparse s)) k =
+ splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
+ k s1 s2 $ Right $ \i k2 ->
+ case eres of
+ Left refl -> case refl of {}
+ Right f ->
+ f IZ $ \wrap e1 e2 ->
+ k2 (\body ->
+ elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
+ (ENothing ext (applySparse s t))
+ (evar IZ)) $
+ wrap body)
+ e1 e2