From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/CHAD/AST.hs | 705 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 705 insertions(+) create mode 100644 src/CHAD/AST.hs (limited to 'src/CHAD/AST.hs') diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs new file mode 100644 index 0000000..aa6aa96 --- /dev/null +++ b/src/CHAD/AST.hs @@ -0,0 +1,705 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where + +import Data.Functor.Const +import Data.Functor.Identity +import Data.Int (Int64) +import Data.Kind (Type) + +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- General assumption: head of the list (whatever way it is associated) is the +-- inner variable / inner array dimension. In pretty printing, the inner +-- variable / inner dimension is printed on the _right_. +-- +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. +-- TODO: ensure this by a "stage" type parameter. +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t where + -- lambda calculus + EVar :: x t -> STy t -> Idx env t -> Expr x env t + ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + + -- base types + EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) + EFst :: x a -> Expr x env (TPair a b) -> Expr x env a + ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b + ENil :: x TNil -> Expr x env TNil + EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) + EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) + ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) + EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + + -- array operations + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) + EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (TPair t1 t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array + + -- expression operations + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t + EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) + EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) + EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + + -- custom derivatives + -- 'b' is the part of the input of the operation that derivatives should + -- be backpropagated to; 'a' is the inactive part. The dual field of + -- ECustom does not allow a derivative to be generated for 'a', and hence + -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! + ECustom :: x t -> STy a -> STy b -> STy tape + -> Expr x [b, a] t -- ^ regular operation + -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr x env a -> Expr x env b + -> Expr x env t + + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + + -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil + + -- monoidal operations (to be desugared to regular operations after simplification) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + + -- partiality + EError :: x a -> STy a -> String -> Expr x env a +deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + +type Ex = Expr (Const ()) + +ext :: Const () a +ext = Const () + +data Commutative = Commut | Noncommut + deriving (Show) + +type SOp :: Ty -> Ty -> Type +data SOp a t where + OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + ONot :: SOp (TScal TBool) (TScal TBool) + OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right + ORound64 :: SOp (TScal TF64) (TScal TI64) + OToFl64 :: SOp (TScal TI64) (TScal TF64) + ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) +deriving instance Show (SOp a t) + +opt1 :: SOp a t -> STy a +opt1 = \case + OAdd t -> STPair (STScal t) (STScal t) + OMul t -> STPair (STScal t) (STScal t) + ONeg t -> STScal t + OLt t -> STPair (STScal t) (STScal t) + OLe t -> STPair (STScal t) (STScal t) + OEq t -> STPair (STScal t) (STScal t) + ONot -> STScal STBool + OAnd -> STPair (STScal STBool) (STScal STBool) + OOr -> STPair (STScal STBool) (STScal STBool) + OIf -> STScal STBool + ORound64 -> STScal STF64 + OToFl64 -> STScal STI64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STPair (STScal t) (STScal t) + OMod t -> STPair (STScal t) (STScal t) + +opt2 :: SOp a t -> STy t +opt2 = \case + OAdd t -> STScal t + OMul t -> STScal t + ONeg t -> STScal t + OLt _ -> STScal STBool + OLe _ -> STScal STBool + OEq _ -> STScal STBool + ONot -> STScal STBool + OAnd -> STScal STBool + OOr -> STScal STBool + OIf -> STEither STNil STNil + ORound64 -> STScal STI64 + OToFl64 -> STScal STF64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STScal t + OMod t -> STScal t + +typeOf :: Expr x env t -> STy t +typeOf = \case + EVar _ t _ -> t + ELet _ _ e -> typeOf e + + EPair _ a b -> STPair (typeOf a) (typeOf b) + EFst _ e | STPair t _ <- typeOf e -> t + ESnd _ e | STPair _ t <- typeOf e -> t + ENil _ -> STNil + EInl _ t2 e -> STEither (typeOf e) t2 + EInr _ t1 e -> STEither t1 (typeOf e) + ECase _ _ a _ -> typeOf a + ENothing _ t -> STMaybe t + EJust _ e -> STMaybe (typeOf e) + EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a + + EConstArr _ n t _ -> STArr n (STScal t) + EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) + EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EUnit _ e -> STArr SZ (typeOf e) + EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t + EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) + + EConst _ t _ -> STScal t + EIdx0 _ e | STArr _ t <- typeOf e -> t + EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t + EIdx _ e _ | STArr _ t <- typeOf e -> t + EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) + EOp _ op _ -> opt2 op + + ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e + + EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ _ _ -> STNil + + EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t + + EError _ t _ -> t + +extOf :: Expr x env t -> x t +extOf = \case + EVar x _ _ -> x + ELet x _ _ -> x + EPair x _ _ -> x + EFst x _ -> x + ESnd x _ -> x + ENil x -> x + EInl x _ _ -> x + EInr x _ _ -> x + ECase x _ _ _ -> x + ENothing x _ -> x + EJust x _ -> x + EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x + EConstArr x _ _ _ -> x + EBuild x _ _ _ -> x + EMap x _ _ -> x + EFold1Inner x _ _ _ _ -> x + ESum1Inner x _ -> x + EUnit x _ -> x + EReplicate1Inner x _ _ -> x + EMaximum1Inner x _ -> x + EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EZip x _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x + EConst x _ _ -> x + EIdx0 x _ -> x + EIdx1 x _ _ -> x + EIdx x _ _ -> x + EShape x _ -> x + EOp x _ _ -> x + ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x + EWith x _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x + EZero x _ _ -> x + EDeepZero x _ _ -> x + EPlus x _ _ _ -> x + EOneHot x _ _ _ _ -> x + EError x _ _ -> x + +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s + +substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline repl = + subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t +subst0 repl = + subst $ \_ t -> \case IZ -> repl + IS i -> EVar ext t (IS i) + +subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) + -> Expr x env t -> Expr x env' t +subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId + +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) + -> env' :> envOut + -> Expr x env t + -> Expr x envOut t +subst' f w = \case + EVar x t i -> f x t w i + ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) + EPair x a b -> EPair x (subst' f w a) (subst' f w b) + EFst x e -> EFst x (subst' f w e) + ESnd x e -> ESnd x (subst' f w e) + ENil x -> ENil x + EInl x t e -> EInl x t (subst' f w e) + EInr x t e -> EInr x t (subst' f w e) + ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + ENothing x t -> ENothing x t + EJust x e -> EJust x (subst' f w e) + EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) + EConstArr x n t a -> EConstArr x n t a + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + ESum1Inner x e -> ESum1Inner x (subst' f w e) + EUnit x e -> EUnit x (subst' f w e) + EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) + EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) + EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (subst' f w e) + EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) + EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) + EShape x e -> EShape x (subst' f w e) + EOp x op e -> EOp x op (subst' f w e) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) + EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) + EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) + EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EError x t s -> EError x t s + where + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF f' x' t w' = \case + IZ -> EVar x' t (w' @> IZ) + IS i -> f' x' t (WPop w') i + +weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict + +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) + +ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) +ebuildUp1 n sh size f = + EBuild ext (SS n) (EPair ext sh size) $ + let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ + in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) + +eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS SZ) a b = + EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) +eidxEq (SS n) a b + | let ty = tTup (sreplicate (SS n) tIx) + = ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EOp ext OAnd $ EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) + (ESnd ext (EVar ext ty IZ)))) + (eidxEq n (EFst ext (EVar ext ty (IS IZ))) + (EFst ext (EVar ext ty IZ))) + +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr _ t <- typeOf arr + , Dict <- styKnown t + = EMap ext f arr + +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr _ t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip = EZip ext + +eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a +eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) + +-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). +eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eshapeEmpty SZ _ = EConst ext STBool False +eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) +eshapeEmpty (SS n) e = + ELet ext e $ + EOp ext OAnd (EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) + (EConst ext STI64 0))) + (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 -- cgit v1.2.3-70-g09d2