aboutsummaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/AST.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs705
1 files changed, 0 insertions, 705 deletions
diff --git a/src/AST.hs b/src/AST.hs
deleted file mode 100644
index ca6cdd1..0000000
--- a/src/AST.hs
+++ /dev/null
@@ -1,705 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImpredicativeTypes #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
-
-import Data.Functor.Const
-import Data.Functor.Identity
-import Data.Int (Int64)
-import Data.Kind (Type)
-
-import Array
-import AST.Accum
-import AST.Sparse.Types
-import AST.Types
-import AST.Weaken
-import CHAD.Types
-import Data
-
-
--- General assumption: head of the list (whatever way it is associated) is the
--- inner variable / inner array dimension. In pretty printing, the inner
--- variable / inner dimension is printed on the _right_.
---
--- All the monoid operations are unsupposed as the input to CHAD, and are
--- intended to be eliminated after simplification, so that the input program as
--- well as the output program do not contain these constructors.
--- TODO: ensure this by a "stage" type parameter.
-type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
-data Expr x env t where
- -- lambda calculus
- EVar :: x t -> STy t -> Idx env t -> Expr x env t
- ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
-
- -- base types
- EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
- EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
- ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
- ENil :: x TNil -> Expr x env TNil
- EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
- EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
- ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
- ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
- EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
- EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
-
- -- array operations
- EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
- -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
- EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
- ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
- EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
- EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
-
- -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
- -- an implementation is allowed to parallelise this thing and store the b
- -- values in some implementation-defined order.
- -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
- EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
- -> Expr x (TPair t1 t1 : env) (TPair t1 b)
- -> Expr x env t1
- -> Expr x env (TArr (S n) t1)
- -> Expr x env (TPair (TArr n t1) -- normal primal fold output
- (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
- -- Reverse derivative of EFold1Inner. The contributions to the initial
- -- element are not yet added together here; we assume a later fusion system
- -- does that for us.
- EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
- -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
- -> Expr x env (TArr n t2) -- incoming cotangent
- -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-
- -- expression operations
- EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
- EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
- EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
- EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
- EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-
- -- custom derivatives
- -- 'b' is the part of the input of the operation that derivatives should
- -- be backpropagated to; 'a' is the inactive part. The dual field of
- -- ECustom does not allow a derivative to be generated for 'a', and hence
- -- none is propagated.
- -- No accumulators are allowed inside a, b and tape. This restriction is
- -- currently not used very much, so could be relaxed in the future; be sure
- -- to check this requirement whenever it is necessary for soundness!
- ECustom :: x t -> STy a -> STy b -> STy tape
- -> Expr x [b, a] t -- ^ regular operation
- -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
- -> Expr x env a -> Expr x env b
- -> Expr x env t
-
- -- fake halfway checkpointing
- ERecompute :: x t -> Expr x env t -> Expr x env t
-
- -- accumulation effect on monoids
- -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
- -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
- -- need to create any zeros.
- EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- -- The 'Sparse' here is eliminated to dense by UnMonoid.
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
-
- -- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
- EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
- EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
- EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
-
- -- interface of abstract monoidal types
- ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
- ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
- ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
- ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-
- -- partiality
- EError :: x a -> STy a -> String -> Expr x env a
-deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
-
-type Ex = Expr (Const ())
-
-ext :: Const () a
-ext = Const ()
-
-data Commutative = Commut | Noncommut
- deriving (Show)
-
-type SOp :: Ty -> Ty -> Type
-data SOp a t where
- OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- ONot :: SOp (TScal TBool) (TScal TBool)
- OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
- OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
- OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
- ORound64 :: SOp (TScal TF64) (TScal TI64)
- OToFl64 :: SOp (TScal TI64) (TScal TF64)
- ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
-deriving instance Show (SOp a t)
-
-opt1 :: SOp a t -> STy a
-opt1 = \case
- OAdd t -> STPair (STScal t) (STScal t)
- OMul t -> STPair (STScal t) (STScal t)
- ONeg t -> STScal t
- OLt t -> STPair (STScal t) (STScal t)
- OLe t -> STPair (STScal t) (STScal t)
- OEq t -> STPair (STScal t) (STScal t)
- ONot -> STScal STBool
- OAnd -> STPair (STScal STBool) (STScal STBool)
- OOr -> STPair (STScal STBool) (STScal STBool)
- OIf -> STScal STBool
- ORound64 -> STScal STF64
- OToFl64 -> STScal STI64
- ORecip t -> STScal t
- OExp t -> STScal t
- OLog t -> STScal t
- OIDiv t -> STPair (STScal t) (STScal t)
- OMod t -> STPair (STScal t) (STScal t)
-
-opt2 :: SOp a t -> STy t
-opt2 = \case
- OAdd t -> STScal t
- OMul t -> STScal t
- ONeg t -> STScal t
- OLt _ -> STScal STBool
- OLe _ -> STScal STBool
- OEq _ -> STScal STBool
- ONot -> STScal STBool
- OAnd -> STScal STBool
- OOr -> STScal STBool
- OIf -> STEither STNil STNil
- ORound64 -> STScal STI64
- OToFl64 -> STScal STF64
- ORecip t -> STScal t
- OExp t -> STScal t
- OLog t -> STScal t
- OIDiv t -> STScal t
- OMod t -> STScal t
-
-typeOf :: Expr x env t -> STy t
-typeOf = \case
- EVar _ t _ -> t
- ELet _ _ e -> typeOf e
-
- EPair _ a b -> STPair (typeOf a) (typeOf b)
- EFst _ e | STPair t _ <- typeOf e -> t
- ESnd _ e | STPair _ t <- typeOf e -> t
- ENil _ -> STNil
- EInl _ t2 e -> STEither (typeOf e) t2
- EInr _ t1 e -> STEither t1 (typeOf e)
- ECase _ _ a _ -> typeOf a
- ENothing _ t -> STMaybe t
- EJust _ e -> STMaybe (typeOf e)
- EMaybe _ e _ _ -> typeOf e
- ELNil _ t1 t2 -> STLEither t1 t2
- ELInl _ t2 e -> STLEither (typeOf e) t2
- ELInr _ t1 e -> STLEither t1 (typeOf e)
- ELCase _ _ a _ _ -> typeOf a
-
- EConstArr _ n t _ -> STArr n (STScal t)
- EBuild _ n _ e -> STArr n (typeOf e)
- EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
- EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
- ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EUnit _ e -> STArr SZ (typeOf e)
- EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
- EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
- EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
-
- EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
- EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
-
- EConst _ t _ -> STScal t
- EIdx0 _ e | STArr _ t <- typeOf e -> t
- EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
- EIdx _ e _ | STArr _ t <- typeOf e -> t
- EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
- EOp _ op _ -> opt2 op
-
- ECustom _ _ _ _ e _ _ _ _ -> typeOf e
- ERecompute _ e -> typeOf e
-
- EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ _ _ -> STNil
-
- EZero _ t _ -> fromSMTy t
- EDeepZero _ t _ -> fromSMTy t
- EPlus _ t _ _ -> fromSMTy t
- EOneHot _ t _ _ _ -> fromSMTy t
-
- EError _ t _ -> t
-
-extOf :: Expr x env t -> x t
-extOf = \case
- EVar x _ _ -> x
- ELet x _ _ -> x
- EPair x _ _ -> x
- EFst x _ -> x
- ESnd x _ -> x
- ENil x -> x
- EInl x _ _ -> x
- EInr x _ _ -> x
- ECase x _ _ _ -> x
- ENothing x _ -> x
- EJust x _ -> x
- EMaybe x _ _ _ -> x
- ELNil x _ _ -> x
- ELInl x _ _ -> x
- ELInr x _ _ -> x
- ELCase x _ _ _ _ -> x
- EConstArr x _ _ _ -> x
- EBuild x _ _ _ -> x
- EMap x _ _ -> x
- EFold1Inner x _ _ _ _ -> x
- ESum1Inner x _ -> x
- EUnit x _ -> x
- EReplicate1Inner x _ _ -> x
- EMaximum1Inner x _ -> x
- EMinimum1Inner x _ -> x
- EReshape x _ _ _ -> x
- EZip x _ _ -> x
- EFold1InnerD1 x _ _ _ _ -> x
- EFold1InnerD2 x _ _ _ _ -> x
- EConst x _ _ -> x
- EIdx0 x _ -> x
- EIdx1 x _ _ -> x
- EIdx x _ _ -> x
- EShape x _ -> x
- EOp x _ _ -> x
- ECustom x _ _ _ _ _ _ _ _ -> x
- ERecompute x _ -> x
- EWith x _ _ _ -> x
- EAccum x _ _ _ _ _ _ -> x
- EZero x _ _ -> x
- EDeepZero x _ _ -> x
- EPlus x _ _ _ -> x
- EOneHot x _ _ _ _ -> x
- EError x _ _ -> x
-
-mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
-mapExt f = runIdentity . travExt (Identity . f)
-
-{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
-travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
-travExt f = \case
- EVar x t i -> EVar <$> f x <*> pure t <*> pure i
- ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
- EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
- EFst x e -> EFst <$> f x <*> travExt f e
- ESnd x e -> ESnd <$> f x <*> travExt f e
- ENil x -> ENil <$> f x
- EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
- EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
- ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
- ENothing x t -> ENothing <$> f x <*> pure t
- EJust x e -> EJust <$> f x <*> travExt f e
- EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
- ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
- ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
- ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
- ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
- EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
- EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
- EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
- EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
- EUnit x e -> EUnit <$> f x <*> travExt f e
- EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
- EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
- EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
- EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
- EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
- EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EConst x t v -> EConst <$> f x <*> pure t <*> pure v
- EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
- EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
- EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
- EShape x e -> EShape <$> f x <*> travExt f e
- EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
- ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
- ERecompute x e -> ERecompute <$> f x <*> travExt f e
- EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
- EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
- EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
- EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
- EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
- EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
- EError x t s -> EError <$> f x <*> pure t <*> pure s
-
-substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
-substInline repl =
- subst $ \x t -> \case IZ -> repl
- IS i -> EVar x t i
-
-subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
-subst0 repl =
- subst $ \_ t -> \case IZ -> repl
- IS i -> EVar ext t (IS i)
-
-subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
- -> Expr x env t -> Expr x env' t
-subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
-
-subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
- -> env' :> envOut
- -> Expr x env t
- -> Expr x envOut t
-subst' f w = \case
- EVar x t i -> f x t w i
- ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
- EPair x a b -> EPair x (subst' f w a) (subst' f w b)
- EFst x e -> EFst x (subst' f w e)
- ESnd x e -> ESnd x (subst' f w e)
- ENil x -> ENil x
- EInl x t e -> EInl x t (subst' f w e)
- EInr x t e -> EInr x t (subst' f w e)
- ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
- ENothing x t -> ENothing x t
- EJust x e -> EJust x (subst' f w e)
- EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
- ELNil x t1 t2 -> ELNil x t1 t2
- ELInl x t e -> ELInl x t (subst' f w e)
- ELInr x t e -> ELInr x t (subst' f w e)
- ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
- EConstArr x n t a -> EConstArr x n t a
- EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
- EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
- ESum1Inner x e -> ESum1Inner x (subst' f w e)
- EUnit x e -> EUnit x (subst' f w e)
- EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
- EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
- EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
- EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
- EZip x a b -> EZip x (subst' f w a) (subst' f w b)
- EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
- EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- EConst x t v -> EConst x t v
- EIdx0 x e -> EIdx0 x (subst' f w e)
- EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
- EShape x e -> EShape x (subst' f w e)
- EOp x op e -> EOp x op (subst' f w e)
- ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
- ERecompute x e -> ERecompute x (subst' f w e)
- EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
- EZero x t e -> EZero x t (subst' f w e)
- EDeepZero x t e -> EDeepZero x t (subst' f w e)
- EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
- EError x t s -> EError x t s
- where
- sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
- sinkF f' x' t w' = \case
- IZ -> EVar x' t (w' @> IZ)
- IS i -> f' x' t (WPop w') i
-
-weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
-weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-
-class KnownScalTy t where knownScalTy :: SScalTy t
-instance KnownScalTy TI32 where knownScalTy = STI32
-instance KnownScalTy TI64 where knownScalTy = STI64
-instance KnownScalTy TF32 where knownScalTy = STF32
-instance KnownScalTy TF64 where knownScalTy = STF64
-instance KnownScalTy TBool where knownScalTy = STBool
-
-class KnownTy t where knownTy :: STy t
-instance KnownTy TNil where knownTy = STNil
-instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
-instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
-instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
-instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
-
-class KnownMTy t where knownMTy :: SMTy t
-instance KnownMTy TNil where knownMTy = SMTNil
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
-instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
-instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
-instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
-
-class KnownEnv env where knownEnv :: SList STy env
-instance KnownEnv '[] where knownEnv = SNil
-instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
-
-styKnown :: STy t -> Dict (KnownTy t)
-styKnown STNil = Dict
-styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STMaybe t) | Dict <- styKnown t = Dict
-styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
-styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- smtyKnown t = Dict
-
-smtyKnown :: SMTy t -> Dict (KnownMTy t)
-smtyKnown SMTNil = Dict
-smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
-smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
-smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
-
-sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
-sscaltyKnown STI32 = Dict
-sscaltyKnown STI64 = Dict
-sscaltyKnown STF32 = Dict
-sscaltyKnown STF64 = Dict
-sscaltyKnown STBool = Dict
-
-envKnown :: SList STy env -> Dict (KnownEnv env)
-envKnown SNil = Dict
-envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-
-cheapExpr :: Expr x env t -> Bool
-cheapExpr = \case
- EVar{} -> True
- ENil{} -> True
- EConst{} -> True
- EFst _ e -> cheapExpr e
- ESnd _ e -> cheapExpr e
- EUnit _ e -> cheapExpr e
- _ -> False
-
-eTup :: SList (Ex env) list -> Ex env (Tup list)
-eTup = mkTup (ENil ext) (EPair ext)
-
-ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
-ebuildUp1 n sh size f =
- EBuild ext (SS n) (EPair ext sh size) $
- let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
- in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
- (EFst ext arg)
-
-eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
-eidxEq SZ _ _ = EConst ext STBool True
-eidxEq (SS SZ) a b =
- EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
-eidxEq (SS n) a b
- | let ty = tTup (sreplicate (SS n) tIx)
- = ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EOp ext OAnd $ EPair ext
- (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
- (ESnd ext (EVar ext ty IZ))))
- (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
- (EFst ext (EVar ext ty IZ)))
-
-emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr
- | STArr _ t <- typeOf arr
- , Dict <- styKnown t
- = EMap ext f arr
-
-ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
-ezipWith f arr1 arr2
- | STArr _ t1 <- typeOf arr1
- , STArr _ t2 <- typeOf arr2
- , Dict <- styKnown t1
- , Dict <- styKnown t2
- = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
- IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
- IS (IS i) -> EVar ext t (IS i))
- f)
- (EZip ext arr1 arr2)
-
-ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip = EZip ext
-
-eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
-eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
-
--- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
-eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
-eshapeEmpty SZ _ = EConst ext STBool False
-eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
-eshapeEmpty (SS n) e =
- ELet ext e $
- EOp ext OAnd (EPair ext
- (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
- (EConst ext STI64 0)))
- (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
-
-eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
-eshapeConst ShNil = ENil ext
-eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
-
-eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
-eshapeProd SZ _ = EConst ext STI64 1
-eshapeProd (SS SZ) e = ESnd ext e
-eshapeProd (SS n) e =
- eunPair e $ \_ e1 e2 ->
- EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
-
-eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
-eflatten e =
- let STArr n _ = typeOf e
- in elet e $
- EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
-
--- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
--- ezeroD2 t ezi = EZero ext (d2M t) ezi
-
--- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
--- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
-
--- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
--- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
-
-eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
-eunPair (EPair _ e1 e2) k = k WId e1 e2
-eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e)
-eunPair e k =
- elet e $
- k WSink
- (EFst ext (evar IZ))
- (ESnd ext (evar IZ))
-
-efst :: Ex env (TPair a b) -> Ex env a
-efst (EPair _ e1 _) = e1
-efst e = EFst ext e
-
-esnd :: Ex env (TPair a b) -> Ex env b
-esnd (EPair _ _ e2) = e2
-esnd e = ESnd ext e
-
-elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
-elet rhs body
- | Dict <- styKnown (typeOf rhs)
- = if cheapExpr rhs
- then substInline rhs body
- else ELet ext rhs body
-
--- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
-use :: Ex env a -> Ex env b -> Ex env b
-use a b = elet a $ weakenExpr WSink b
-
-emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
-emaybe e a b
- | STMaybe t <- typeOf e
- , Dict <- styKnown t
- = EMaybe ext a b e
-
-ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
-ecase e a b
- | STEither t1 t2 <- typeOf e
- , Dict <- styKnown t1
- , Dict <- styKnown t2
- = ECase ext e a b
-
-elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
-elcase e a b c
- | STLEither t1 t2 <- typeOf e
- , Dict <- styKnown t1
- , Dict <- styKnown t2
- = ELCase ext e a b c
-
-evar :: KnownTy a => Idx env a -> Ex env a
-evar = EVar ext knownTy
-
-makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
-makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
- where
- -- invariant: expression argument is duplicable
- go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
- go SMTNil _ = ENil ext
- go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
- go SMTLEither{} _ = ENil ext
- go SMTMaybe{} _ = ENil ext
- go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
- go SMTScal{} _ = ENil ext
-
-splitSparsePair
- :: -- given a sparsity
- STy (TPair a b) -> Sparse (TPair a b) t'
- -> (forall a' b'.
- -- I give you back two sparsities for a and b
- Sparse a a' -> Sparse b b'
- -- furthermore, I tell you that either your t' is already this (a', b') pair...
- -> Either
- (t' :~: TPair a' b')
- -- or I tell you how to construct a' and b' from t', given an actual t'
- (forall r' env.
- Idx env t'
- -> (forall env'.
- (forall c. Ex env' c -> Ex env c)
- -> Ex env' a' -> Ex env' b' -> r')
- -> r')
- -> r)
- -> r
-splitSparsePair _ SpAbsent k =
- k SpAbsent SpAbsent $ Right $ \_ k2 ->
- k2 id (ENil ext) (ENil ext)
-splitSparsePair _ (SpPair s1 s2) k1 =
- k1 s1 s2 $ Left Refl
-splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
- let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
- k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
- k2 (elet $
- emaybe (EVar ext (STMaybe (applySparse s t)) i)
- (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
- (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
- (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ))
-
-splitSparsePair _ (SpSparse SpAbsent) k =
- k SpAbsent SpAbsent $ Right $ \_ k2 ->
- k2 id (ENil ext) (ENil ext)
--- -- TODO: having to handle sparse-of-sparse at all is ridiculous
-splitSparsePair t (SpSparse (SpSparse s)) k =
- splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
- k s1 s2 $ Right $ \i k2 ->
- case eres of
- Left refl -> case refl of {}
- Right f ->
- f IZ $ \wrap e1 e2 ->
- k2 (\body ->
- elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
- (ENothing ext (applySparse s t))
- (evar IZ)) $
- wrap body)
- e1 e2