aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/AST.hs705
-rw-r--r--src/CHAD/AST/Accum.hs137
-rw-r--r--src/CHAD/AST/Bindings.hs84
-rw-r--r--src/CHAD/AST/Count.hs930
-rw-r--r--src/CHAD/AST/Env.hs95
-rw-r--r--src/CHAD/AST/Pretty.hs525
-rw-r--r--src/CHAD/AST/Sparse.hs287
-rw-r--r--src/CHAD/AST/Sparse/Types.hs107
-rw-r--r--src/CHAD/AST/SplitLets.hs191
-rw-r--r--src/CHAD/AST/Types.hs215
-rw-r--r--src/CHAD/AST/UnMonoid.hs255
-rw-r--r--src/CHAD/AST/Weaken.hs138
-rw-r--r--src/CHAD/AST/Weaken/Auto.hs192
-rw-r--r--src/CHAD/Analysis/Identity.hs436
-rw-r--r--src/CHAD/Array.hs131
-rw-r--r--src/CHAD/Compile.hs1796
-rw-r--r--src/CHAD/Compile/Exec.hs99
-rw-r--r--src/CHAD/Data.hs192
-rw-r--r--src/CHAD/Data/VarMap.hs119
-rw-r--r--src/CHAD/Drev.hs1583
-rw-r--r--src/CHAD/Drev/Accum.hs (renamed from src/CHAD/Accum.hs)12
-rw-r--r--src/CHAD/Drev/EnvDescr.hs (renamed from src/CHAD/EnvDescr.hs)14
-rw-r--r--src/CHAD/Drev/Top.hs (renamed from src/CHAD/Top.hs)26
-rw-r--r--src/CHAD/Drev/Types.hs (renamed from src/CHAD/Types.hs)8
-rw-r--r--src/CHAD/Drev/Types/ToTan.hs (renamed from src/CHAD/Types/ToTan.hs)14
-rw-r--r--src/CHAD/Example.hs197
-rw-r--r--src/CHAD/Example/GMM.hs124
-rw-r--r--src/CHAD/Example/Types.hs11
-rw-r--r--src/CHAD/ForwardAD.hs270
-rw-r--r--src/CHAD/ForwardAD/DualNumbers.hs231
-rw-r--r--src/CHAD/ForwardAD/DualNumbers/Types.hs48
-rw-r--r--src/CHAD/Interpreter.hs471
-rw-r--r--src/CHAD/Interpreter/Accum.hs366
-rw-r--r--src/CHAD/Interpreter/AccumOld.hs366
-rw-r--r--src/CHAD/Interpreter/Rep.hs105
-rw-r--r--src/CHAD/Language.hs266
-rw-r--r--src/CHAD/Language/AST.hs300
-rw-r--r--src/CHAD/Lemmas.hs21
-rw-r--r--src/CHAD/Simplify.hs619
-rw-r--r--src/CHAD/Simplify/TH.hs80
-rw-r--r--src/CHAD/Util/IdGen.hs19
41 files changed, 11748 insertions, 37 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
new file mode 100644
index 0000000..aa6aa96
--- /dev/null
+++ b/src/CHAD/AST.hs
@@ -0,0 +1,705 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where
+
+import Data.Functor.Const
+import Data.Functor.Identity
+import Data.Int (Int64)
+import Data.Kind (Type)
+
+import CHAD.Array
+import CHAD.AST.Accum
+import CHAD.AST.Sparse.Types
+import CHAD.AST.Types
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+-- General assumption: head of the list (whatever way it is associated) is the
+-- inner variable / inner array dimension. In pretty printing, the inner
+-- variable / inner dimension is printed on the _right_.
+--
+-- All the monoid operations are unsupposed as the input to CHAD, and are
+-- intended to be eliminated after simplification, so that the input program as
+-- well as the output program do not contain these constructors.
+-- TODO: ensure this by a "stage" type parameter.
+type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr x env t where
+ -- lambda calculus
+ EVar :: x t -> STy t -> Idx env t -> Expr x env t
+ ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
+
+ -- base types
+ EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
+ EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
+ ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
+ ENil :: x TNil -> Expr x env TNil
+ EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
+ EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
+ ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+ ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
+ EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
+ EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
+
+ -- array operations
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
+ EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
+ -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
+ EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
+
+ -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
+ -- an implementation is allowed to parallelise this thing and store the b
+ -- values in some implementation-defined order.
+ -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
+ -> Expr x (TPair t1 t1 : env) (TPair t1 b)
+ -> Expr x env t1
+ -> Expr x env (TArr (S n) t1)
+ -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
+ -- Reverse derivative of EFold1Inner. The contributions to the initial
+ -- element are not yet added together here; we assume a later fusion system
+ -- does that for us.
+ EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
+ -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
+ -> Expr x env (TArr n t2) -- incoming cotangent
+ -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
+
+ -- expression operations
+ EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
+ EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
+ EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
+ EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
+ EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
+ EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+
+ -- custom derivatives
+ -- 'b' is the part of the input of the operation that derivatives should
+ -- be backpropagated to; 'a' is the inactive part. The dual field of
+ -- ECustom does not allow a derivative to be generated for 'a', and hence
+ -- none is propagated.
+ -- No accumulators are allowed inside a, b and tape. This restriction is
+ -- currently not used very much, so could be relaxed in the future; be sure
+ -- to check this requirement whenever it is necessary for soundness!
+ ECustom :: x t -> STy a -> STy b -> STy tape
+ -> Expr x [b, a] t -- ^ regular operation
+ -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x env a -> Expr x env b
+ -> Expr x env t
+
+ -- fake halfway checkpointing
+ ERecompute :: x t -> Expr x env t -> Expr x env t
+
+ -- accumulation effect on monoids
+ -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
+ -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
+ -- need to create any zeros.
+ EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ -- The 'Sparse' here is eliminated to dense by UnMonoid.
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
+
+ -- monoidal operations (to be desugared to regular operations after simplification)
+ EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
+ EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+
+ -- interface of abstract monoidal types
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+
+ -- partiality
+ EError :: x a -> STy a -> String -> Expr x env a
+deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+
+type Ex = Expr (Const ())
+
+ext :: Const () a
+ext = Const ()
+
+data Commutative = Commut | Noncommut
+ deriving (Show)
+
+type SOp :: Ty -> Ty -> Type
+data SOp a t where
+ OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ ONot :: SOp (TScal TBool) (TScal TBool)
+ OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
+ ORound64 :: SOp (TScal TF64) (TScal TI64)
+ OToFl64 :: SOp (TScal TI64) (TScal TF64)
+ ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+deriving instance Show (SOp a t)
+
+opt1 :: SOp a t -> STy a
+opt1 = \case
+ OAdd t -> STPair (STScal t) (STScal t)
+ OMul t -> STPair (STScal t) (STScal t)
+ ONeg t -> STScal t
+ OLt t -> STPair (STScal t) (STScal t)
+ OLe t -> STPair (STScal t) (STScal t)
+ OEq t -> STPair (STScal t) (STScal t)
+ ONot -> STScal STBool
+ OAnd -> STPair (STScal STBool) (STScal STBool)
+ OOr -> STPair (STScal STBool) (STScal STBool)
+ OIf -> STScal STBool
+ ORound64 -> STScal STF64
+ OToFl64 -> STScal STI64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
+ OIDiv t -> STPair (STScal t) (STScal t)
+ OMod t -> STPair (STScal t) (STScal t)
+
+opt2 :: SOp a t -> STy t
+opt2 = \case
+ OAdd t -> STScal t
+ OMul t -> STScal t
+ ONeg t -> STScal t
+ OLt _ -> STScal STBool
+ OLe _ -> STScal STBool
+ OEq _ -> STScal STBool
+ ONot -> STScal STBool
+ OAnd -> STScal STBool
+ OOr -> STScal STBool
+ OIf -> STEither STNil STNil
+ ORound64 -> STScal STI64
+ OToFl64 -> STScal STF64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
+ OIDiv t -> STScal t
+ OMod t -> STScal t
+
+typeOf :: Expr x env t -> STy t
+typeOf = \case
+ EVar _ t _ -> t
+ ELet _ _ e -> typeOf e
+
+ EPair _ a b -> STPair (typeOf a) (typeOf b)
+ EFst _ e | STPair t _ <- typeOf e -> t
+ ESnd _ e | STPair _ t <- typeOf e -> t
+ ENil _ -> STNil
+ EInl _ t2 e -> STEither (typeOf e) t2
+ EInr _ t1 e -> STEither t1 (typeOf e)
+ ECase _ _ a _ -> typeOf a
+ ENothing _ t -> STMaybe t
+ EJust _ e -> STMaybe (typeOf e)
+ EMaybe _ e _ _ -> typeOf e
+ ELNil _ t1 t2 -> STLEither t1 t2
+ ELInl _ t2 e -> STLEither (typeOf e) t2
+ ELInr _ t1 e -> STLEither t1 (typeOf e)
+ ELCase _ _ a _ _ -> typeOf a
+
+ EConstArr _ n t _ -> STArr n (STScal t)
+ EBuild _ n _ e -> STArr n (typeOf e)
+ EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
+ EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EUnit _ e -> STArr SZ (typeOf e)
+ EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
+ EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
+
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
+ EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
+
+ EConst _ t _ -> STScal t
+ EIdx0 _ e | STArr _ t <- typeOf e -> t
+ EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
+ EIdx _ e _ | STArr _ t <- typeOf e -> t
+ EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
+ EOp _ op _ -> opt2 op
+
+ ECustom _ _ _ _ e _ _ _ _ -> typeOf e
+ ERecompute _ e -> typeOf e
+
+ EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
+ EAccum _ _ _ _ _ _ _ -> STNil
+
+ EZero _ t _ -> fromSMTy t
+ EDeepZero _ t _ -> fromSMTy t
+ EPlus _ t _ _ -> fromSMTy t
+ EOneHot _ t _ _ _ -> fromSMTy t
+
+ EError _ t _ -> t
+
+extOf :: Expr x env t -> x t
+extOf = \case
+ EVar x _ _ -> x
+ ELet x _ _ -> x
+ EPair x _ _ -> x
+ EFst x _ -> x
+ ESnd x _ -> x
+ ENil x -> x
+ EInl x _ _ -> x
+ EInr x _ _ -> x
+ ECase x _ _ _ -> x
+ ENothing x _ -> x
+ EJust x _ -> x
+ EMaybe x _ _ _ -> x
+ ELNil x _ _ -> x
+ ELInl x _ _ -> x
+ ELInr x _ _ -> x
+ ELCase x _ _ _ _ -> x
+ EConstArr x _ _ _ -> x
+ EBuild x _ _ _ -> x
+ EMap x _ _ -> x
+ EFold1Inner x _ _ _ _ -> x
+ ESum1Inner x _ -> x
+ EUnit x _ -> x
+ EReplicate1Inner x _ _ -> x
+ EMaximum1Inner x _ -> x
+ EMinimum1Inner x _ -> x
+ EReshape x _ _ _ -> x
+ EZip x _ _ -> x
+ EFold1InnerD1 x _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ -> x
+ EConst x _ _ -> x
+ EIdx0 x _ -> x
+ EIdx1 x _ _ -> x
+ EIdx x _ _ -> x
+ EShape x _ -> x
+ EOp x _ _ -> x
+ ECustom x _ _ _ _ _ _ _ _ -> x
+ ERecompute x _ -> x
+ EWith x _ _ _ -> x
+ EAccum x _ _ _ _ _ _ -> x
+ EZero x _ _ -> x
+ EDeepZero x _ _ -> x
+ EPlus x _ _ _ -> x
+ EOneHot x _ _ _ _ -> x
+ EError x _ _ -> x
+
+mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt f = runIdentity . travExt (Identity . f)
+
+{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
+travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt f = \case
+ EVar x t i -> EVar <$> f x <*> pure t <*> pure i
+ ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
+ EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
+ EFst x e -> EFst <$> f x <*> travExt f e
+ ESnd x e -> ESnd <$> f x <*> travExt f e
+ ENil x -> ENil <$> f x
+ EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
+ EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
+ ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
+ ENothing x t -> ENothing <$> f x <*> pure t
+ EJust x e -> EJust <$> f x <*> travExt f e
+ EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
+ ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
+ ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
+ ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
+ ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
+ EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
+ EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
+ EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
+ EUnit x e -> EUnit <$> f x <*> travExt f e
+ EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
+ EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
+ EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
+ EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EConst x t v -> EConst <$> f x <*> pure t <*> pure v
+ EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
+ EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
+ EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
+ EShape x e -> EShape <$> f x <*> travExt f e
+ EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
+ ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ ERecompute x e -> ERecompute <$> f x <*> travExt f e
+ EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
+ EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
+ EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
+ EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
+ EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
+ EError x t s -> EError <$> f x <*> pure t <*> pure s
+
+substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+substInline repl =
+ subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
+subst0 repl =
+ subst $ \_ t -> \case IZ -> repl
+ IS i -> EVar ext t (IS i)
+
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
+ -> Expr x env t -> Expr x env' t
+subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
+
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+ -> env' :> envOut
+ -> Expr x env t
+ -> Expr x envOut t
+subst' f w = \case
+ EVar x t i -> f x t w i
+ ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
+ EPair x a b -> EPair x (subst' f w a) (subst' f w b)
+ EFst x e -> EFst x (subst' f w e)
+ ESnd x e -> ESnd x (subst' f w e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (subst' f w e)
+ EInr x t e -> EInr x t (subst' f w e)
+ ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (subst' f w e)
+ EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (subst' f w e)
+ ELInr x t e -> ELInr x t (subst' f w e)
+ ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
+ EConstArr x n t a -> EConstArr x n t a
+ EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
+ EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
+ ESum1Inner x e -> ESum1Inner x (subst' f w e)
+ EUnit x e -> EUnit x (subst' f w e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
+ EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
+ EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
+ EZip x a b -> EZip x (subst' f w a) (subst' f w b)
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (subst' f w e)
+ EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
+ EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
+ EShape x e -> EShape x (subst' f w e)
+ EOp x op e -> EOp x op (subst' f w e)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
+ ERecompute x e -> ERecompute x (subst' f w e)
+ EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
+ EZero x t e -> EZero x t (subst' f w e)
+ EDeepZero x t e -> EDeepZero x t (subst' f w e)
+ EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EError x t s -> EError x t s
+ where
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF f' x' t w' = \case
+ IZ -> EVar x' t (w' @> IZ)
+ IS i -> f' x' t (WPop w') i
+
+weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
+
+class KnownScalTy t where knownScalTy :: SScalTy t
+instance KnownScalTy TI32 where knownScalTy = STI32
+instance KnownScalTy TI64 where knownScalTy = STI64
+instance KnownScalTy TF32 where knownScalTy = STF32
+instance KnownScalTy TF64 where knownScalTy = STF64
+instance KnownScalTy TBool where knownScalTy = STBool
+
+class KnownTy t where knownTy :: STy t
+instance KnownTy TNil where knownTy = STNil
+instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
+instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
+instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
+instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+
+class KnownEnv env where knownEnv :: SList STy env
+instance KnownEnv '[] where knownEnv = SNil
+instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STMaybe t) | Dict <- styKnown t = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
+envKnown :: SList STy env -> Dict (KnownEnv env)
+envKnown SNil = Dict
+envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ EFst _ e -> cheapExpr e
+ ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
+ _ -> False
+
+eTup :: SList (Ex env) list -> Ex env (Tup list)
+eTup = mkTup (ENil ext) (EPair ext)
+
+ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
+ebuildUp1 n sh size f =
+ EBuild ext (SS n) (EPair ext sh size) $
+ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
+ in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
+ (EFst ext arg)
+
+eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eidxEq SZ _ _ = EConst ext STBool True
+eidxEq (SS SZ) a b =
+ EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
+eidxEq (SS n) a b
+ | let ty = tTup (sreplicate (SS n) tIx)
+ = ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EOp ext OAnd $ EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
+ (ESnd ext (EVar ext ty IZ))))
+ (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
+ (EFst ext (EVar ext ty IZ)))
+
+emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr
+ | STArr _ t <- typeOf arr
+ , Dict <- styKnown t
+ = EMap ext f arr
+
+ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
+ezipWith f arr1 arr2
+ | STArr _ t1 <- typeOf arr1
+ , STArr _ t2 <- typeOf arr2
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
+ IS (IS i) -> EVar ext t (IS i))
+ f)
+ (EZip ext arr1 arr2)
+
+ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
+ezip = EZip ext
+
+eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
+eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
+
+-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
+eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eshapeEmpty SZ _ = EConst ext STBool False
+eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
+eshapeEmpty (SS n) e =
+ ELet ext e $
+ EOp ext OAnd (EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
+ (EConst ext STI64 0)))
+ (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
+eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
+eshapeConst ShNil = ENil ext
+eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+
+eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
+eshapeProd SZ _ = EConst ext STI64 1
+eshapeProd (SS SZ) e = ESnd ext e
+eshapeProd (SS n) e =
+ eunPair e $ \_ e1 e2 ->
+ EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
+
+eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
+eflatten e =
+ let STArr n _ = typeOf e
+ in elet e $
+ EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
+
+-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
+-- ezeroD2 t ezi = EZero ext (d2M t) ezi
+
+-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
+-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
+
+-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
+-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
+
+eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
+eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e)
+eunPair e k =
+ elet e $
+ k WSink
+ (EFst ext (evar IZ))
+ (ESnd ext (evar IZ))
+
+efst :: Ex env (TPair a b) -> Ex env a
+efst (EPair _ e1 _) = e1
+efst e = EFst ext e
+
+esnd :: Ex env (TPair a b) -> Ex env b
+esnd (EPair _ _ e2) = e2
+esnd e = ESnd ext e
+
+elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
+elet rhs body
+ | Dict <- styKnown (typeOf rhs)
+ = if cheapExpr rhs
+ then substInline rhs body
+ else ELet ext rhs body
+
+-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
+emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
+emaybe e a b
+ | STMaybe t <- typeOf e
+ , Dict <- styKnown t
+ = EMaybe ext a b e
+
+ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+ecase e a b
+ | STEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ECase ext e a b
+
+elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+elcase e a b c
+ | STLEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELCase ext e a b c
+
+evar :: KnownTy a => Idx env a -> Ex env a
+evar = EVar ext knownTy
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext
+
+splitSparsePair
+ :: -- given a sparsity
+ STy (TPair a b) -> Sparse (TPair a b) t'
+ -> (forall a' b'.
+ -- I give you back two sparsities for a and b
+ Sparse a a' -> Sparse b b'
+ -- furthermore, I tell you that either your t' is already this (a', b') pair...
+ -> Either
+ (t' :~: TPair a' b')
+ -- or I tell you how to construct a' and b' from t', given an actual t'
+ (forall r' env.
+ Idx env t'
+ -> (forall env'.
+ (forall c. Ex env' c -> Ex env c)
+ -> Ex env' a' -> Ex env' b' -> r')
+ -> r')
+ -> r)
+ -> r
+splitSparsePair _ SpAbsent k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+splitSparsePair _ (SpPair s1 s2) k1 =
+ k1 s1 s2 $ Left Refl
+splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
+ let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
+ k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
+ k2 (elet $
+ emaybe (EVar ext (STMaybe (applySparse s t)) i)
+ (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
+ (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
+ (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ))
+
+splitSparsePair _ (SpSparse SpAbsent) k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+-- -- TODO: having to handle sparse-of-sparse at all is ridiculous
+splitSparsePair t (SpSparse (SpSparse s)) k =
+ splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
+ k s1 s2 $ Right $ \i k2 ->
+ case eres of
+ Left refl -> case refl of {}
+ Right f ->
+ f IZ $ \wrap e1 e2 ->
+ k2 (\body ->
+ elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
+ (ENothing ext (applySparse s t))
+ (evar IZ)) $
+ wrap body)
+ e1 e2
diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs
new file mode 100644
index 0000000..ea74a95
--- /dev/null
+++ b/src/CHAD/AST/Accum.hs
@@ -0,0 +1,137 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST.Accum where
+
+import CHAD.AST.Types
+import CHAD.Data
+
+
+data AcPrj
+ = APHere
+ | APFst AcPrj
+ | APSnd AcPrj
+ | APLeft AcPrj
+ | APRight AcPrj
+ | APJust AcPrj
+ | APArrIdx AcPrj
+ | APArrSlice Nat
+
+-- | @b@ is a small part of @a@, indicated by the projection @p@.
+data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
+ SAPHere :: SAcPrj APHere a a
+ SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
+ SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b
+ SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b
+ SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
+ SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b
+ -- TODO:
+ -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
+deriving instance Show (SAcPrj p a b)
+
+type data AIDense = AID | AIS
+
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
+ -- -- (index, array shape)
+ -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
+
+acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
+acPrjTy SAPHere t = t
+acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
+
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+
+-- -- | Additional info needed for accumulation. This is empty unless there is
+-- -- sparsity in the monoid.
+-- type family AccumInfo t where
+-- AccumInfo TNil = TNil
+-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
+-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
+-- AccumInfo (TArr n t) = TArr n (AccumInfo t)
+-- AccumInfo (TScal t) = TNil
+
+-- type family PrimalInfo t where
+-- PrimalInfo TNil = TNil
+-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
+-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
+-- PrimalInfo (TScal t) = TNil
+
+-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
+-- tPrimalInfo SMTNil = STNil
+-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
+-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
+-- tPrimalInfo (SMTScal _) = STNil
diff --git a/src/CHAD/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs
new file mode 100644
index 0000000..c1a1e77
--- /dev/null
+++ b/src/CHAD/AST/Bindings.hs
@@ -0,0 +1,84 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module CHAD.AST.Bindings where
+
+import CHAD.AST
+import CHAD.AST.Env
+import CHAD.Data
+import CHAD.Lemmas
+
+
+-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.
+data Bindings f env binds where
+ BTop :: Bindings f env '[]
+ BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)
+deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
+infixl `BPush`
+
+bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds)
+bpush b e = b `BPush` (typeOf e, e)
+infixl `bpush`
+
+mapBindings :: (forall env' t'. f env' t' -> g env' t')
+ -> Bindings f env binds -> Bindings g env binds
+mapBindings _ BTop = BTop
+mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e)
+
+weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> env1 :> env2
+ -> Bindings f env1 binds
+ -> (Bindings f env2 binds, Append binds env1 :> Append binds env2)
+weakenBindings _ w BTop = (BTop, w)
+weakenBindings wf w (BPush b (t, x)) =
+ let (b', w') = weakenBindings wf w b
+ in (BPush b' (t, wf w' x), WCopy w')
+
+weakenBindingsE :: env1 :> env2
+ -> Bindings (Expr x) env1 binds
+ -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2)
+weakenBindingsE = weakenBindings weakenExpr
+
+weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
+weakenOver SNil w = w
+weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
+
+sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env'
+sinkWithBindings BTop = WId
+sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
+
+bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1)
+bconcat b1 BTop = b1
+bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))
+ | Refl <- lemAppendAssoc @binds2C @binds1 @env
+ = BPush (bconcat b1 b2) (t, x)
+
+bindingsBinds :: Bindings f env binds -> SList STy binds
+bindingsBinds BTop = SNil
+bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
+
+letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
+letBinds BTop = id
+letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYesR sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs
new file mode 100644
index 0000000..133093a
--- /dev/null
+++ b/src/CHAD/AST/Count.hs
@@ -0,0 +1,930 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE PatternSynonyms #-}
+module CHAD.AST.Count where
+
+import Data.Functor.Product
+import Data.Some
+import Data.Type.Equality
+import GHC.Generics (Generic, Generically(..))
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Env
+import CHAD.Data
+
+
+-- | The monoid operation combines assuming that /both/ branches are taken.
+class Monoid a => Occurrence a where
+ -- | One of the two branches is taken
+ (<||>) :: a -> a -> a
+ -- | This code is executed many times
+ scaleMany :: a -> a
+
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+instance Occurrence Count where
+ (<||>) = max
+ scaleMany Zero = Zero
+ scaleMany _ = Many
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+ deriving (Eq, Generic)
+ deriving (Semigroup, Monoid) via Generically Occ
+
+instance Show Occ where
+ showsPrec d (Occ l r) = showParen (d > 10) $
+ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
+
+instance Occurrence Occ where
+ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2)
+ scaleMany (Occ l c) = Occ l (scaleMany c)
+
+
+data Substruc t t' where
+ -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below
+ SsFull :: Substruc t t
+ SsNone :: Substruc t TNil
+ SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b')
+ SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b')
+ SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b')
+ SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a')
+ SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements
+ SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a')
+
+pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t'
+pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2)
+ where SsPair' = SsPair
+{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-}
+
+pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t'
+pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s)
+ where SsArr' = SsArr
+{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-}
+
+instance Semigroup (Some (Substruc t)) where
+ Some SsFull <> _ = Some SsFull
+ _ <> Some SsFull = Some SsFull
+ Some SsNone <> s = s
+ s <> Some SsNone = s
+ Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2)
+ Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2)
+ Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2)
+ Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2)
+ Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2)
+ Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2)
+instance Monoid (Some (Substruc t)) where
+ mempty = Some SsNone
+
+instance TestEquality (Substruc t) where
+ testEquality SsFull s = isFull s
+ testEquality s SsFull = sym <$> isFull s
+ testEquality SsNone SsNone = Just Refl
+ testEquality SsNone _ = Nothing
+ testEquality _ SsNone = Nothing
+ testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+
+isFull :: Substruc t t' -> Maybe (t :~: t')
+isFull SsFull = Just Refl
+isFull SsNone = Nothing -- TODO: nil?
+isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+
+applySubstruc :: Substruc t t' -> STy t -> STy t'
+applySubstruc SsFull t = t
+applySubstruc SsNone _ = STNil
+applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t)
+applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t)
+applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t)
+
+applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t'
+applySubstrucM SsFull t = t
+applySubstrucM SsNone _ = SMTNil
+applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t)
+applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t)
+applySubstrucM _ t = case t of {}
+
+data ExMap a b = ExMap (forall env. Ex env a -> Ex env b)
+ | a ~ b => ExMapId
+
+fromExMap :: ExMap a b -> Ex env a -> Ex env b
+fromExMap (ExMap f) = f
+fromExMap ExMapId = id
+
+simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t'
+simplifySubstruc STNil SsNone = SsFull
+
+simplifySubstruc _ SsFull = SsFull
+simplifySubstruc _ SsNone = SsNone
+simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s)
+simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s)
+simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s)
+
+-- simplifySubstruc' :: Substruc t t'
+-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r
+-- simplifySubstruc' SsFull k = k SsFull ExMapId
+-- simplifySubstruc' SsNone k = k SsNone ExMapId
+-- simplifySubstruc' (SsPair s1 s2) k =
+-- simplifySubstruc' s1 $ \s1' f1 ->
+-- simplifySubstruc' s2 $ \s2' f2 ->
+-- case (s1', s2') of
+-- (SsFull, SsFull) ->
+-- k SsFull (case (f1, f2) of
+-- (ExMapId, ExMapId) -> ExMapId
+-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 ->
+-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2)))
+-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext))))
+-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ)))))
+-- simplifySubstruc' _ _ = _
+
+-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b)
+-- ssUnpair SsFull = (SsFull, SsFull)
+-- ssUnpair SsNone = (SsNone, SsNone)
+-- ssUnpair (SsPair a b) = (a, b)
+
+-- ssUnleft :: Substruc (TEither a b) -> Substruc a
+-- ssUnleft SsFull = SsFull
+-- ssUnleft SsNone = SsNone
+-- ssUnleft (SsEither a _) = a
+
+-- ssUnright :: Substruc (TEither a b) -> Substruc b
+-- ssUnright SsFull = SsFull
+-- ssUnright SsNone = SsNone
+-- ssUnright (SsEither _ b) = b
+
+-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a
+-- ssUnlleft SsFull = SsFull
+-- ssUnlleft SsNone = SsNone
+-- ssUnlleft (SsLEither a _) = a
+
+-- ssUnlright :: Substruc (TLEither a b) -> Substruc b
+-- ssUnlright SsFull = SsFull
+-- ssUnlright SsNone = SsNone
+-- ssUnlright (SsLEither _ b) = b
+
+-- ssUnjust :: Substruc (TMaybe a) -> Substruc a
+-- ssUnjust SsFull = SsFull
+-- ssUnjust SsNone = SsNone
+-- ssUnjust (SsMaybe a) = a
+
+-- ssUnarr :: Substruc (TArr n a) -> Substruc a
+-- ssUnarr SsFull = SsFull
+-- ssUnarr SsNone = SsNone
+-- ssUnarr (SsArr a) = a
+
+-- ssUnaccum :: Substruc (TAccum a) -> Substruc a
+-- ssUnaccum SsFull = SsFull
+-- ssUnaccum SsNone = SsNone
+-- ssUnaccum (SsAccum a) = a
+
+
+type family MapEmpty env where
+ MapEmpty '[] = '[]
+ MapEmpty (t : env) = TNil : MapEmpty env
+
+data OccEnv a env env' where
+ OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top!
+ OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env')
+
+instance Semigroup a => Semigroup (Some (OccEnv a env)) where
+ Some OccEnd <> e = e
+ e <> Some OccEnd = e
+ Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2)
+
+instance Semigroup a => Monoid (Some (OccEnv a env)) where
+ mempty = Some OccEnd
+
+instance Occurrence a => Occurrence (Some (OccEnv a env)) where
+ Some OccEnd <||> e = e
+ e <||> Some OccEnd = e
+ Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2)
+
+ scaleMany (Some OccEnd) = Some OccEnd
+ scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s)
+
+onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env)
+onehotOccEnv IZ v s = Some (OccPush OccEnd v s)
+onehotOccEnv (IS i) v s
+ | Some env' <- onehotOccEnv i v s
+ = Some (OccPush env' mempty SsNone)
+
+occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t')
+occEnvPop (OccPush e _ s) = (e, s)
+occEnvPop OccEnd = (OccEnd, SsNone)
+
+occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r
+occEnvPop' (OccPush e _ s) k = k e s
+occEnvPop' OccEnd k = k OccEnd SsNone
+
+occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env)
+occEnvPopSome (Some (OccPush e _ _)) = Some e
+occEnvPopSome (Some OccEnd) = Some OccEnd
+
+occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t))
+occEnvPrj OccEnd _ = mempty
+occEnvPrj (OccPush _ o s) IZ = (o, Some s)
+occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i
+
+occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env'))
+occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ)
+occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i'))
+occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ)
+occEnvPrjS (OccPush e _ _) (IS i)
+ | Some (Pair s' i') <- occEnvPrjS e i
+ = Some (Pair s' (IS i'))
+
+projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small
+projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
+ _ | Just Refl <- testEquality topsbig topssmall -> ex
+
+ (SsFull, SsFull) -> ex
+ (SsNone, SsNone) -> ex
+ (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller"
+ (_, SsNone) ->
+ case typeOf ex of
+ STNil -> ex
+ _ -> use ex $ ENil ext
+
+ (SsPair s1 s2, SsPair s1' s2') ->
+ eunPair ex $ \_ e1 e2 ->
+ EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2)
+ (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex
+ (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex
+
+ (SsEither s1 s2, SsEither s1' s2')
+ | STEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in ecase ex
+ (EInl ext (typeOf e2) e1)
+ (EInr ext (typeOf e1) e2)
+ (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex
+ (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex
+
+ (SsLEither s1 s2, SsLEither s1' s2')
+ | STLEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in elcase ex
+ (ELNil ext (typeOf e1) (typeOf e2))
+ (ELInl ext (typeOf e2) e1)
+ (ELInr ext (typeOf e1) e2)
+ (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex
+ (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex
+
+ (SsMaybe s1, SsMaybe s1')
+ | STMaybe t1 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ in emaybe ex
+ (ENothing ext (typeOf e1))
+ (EJust ext e1)
+ (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
+ (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
+
+ (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex
+ (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
+ (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
+
+ (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum"
+ (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex
+ (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex
+
+
+-- | A boolean for each entry in the environment, with the ability to uniformly
+-- mask the top part above a certain index.
+data EnvMask env where
+ EMRest :: Bool -> EnvMask env
+ EMPush :: EnvMask env -> Bool -> EnvMask (t : env)
+
+envMaskPrj :: EnvMask env -> Idx env t -> Bool
+envMaskPrj (EMRest b) _ = b
+envMaskPrj (_ `EMPush` b) IZ = b
+envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx ex
+ | Some env <- occCountAll ex
+ = fst (occEnvPrj env idx)
+
+occCountAll :: Expr x env t -> Some (OccEnv Occ env)
+occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
+
+pruneExpr :: SList f env -> Expr x env t -> Ex env t
+pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
+ where
+ fullOccEnv :: SList f env -> OccEnv () env env
+ fullOccEnv SNil = OccEnd
+ fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
+
+-- In one traversal, count occurrences of variables and determine what parts of
+-- expressions are actually used. These two results are computed independently:
+-- even if (almost) nothing of a particular term is actually used, variable
+-- references in that term still count as usual.
+--
+-- In @occCountX s t k@:
+-- * s: how much of the result of this term is required
+-- * t: the term to analyse
+-- * k: is passed the actual environment usage of this expression, including
+-- occurrence counts. The callback reconstructs a new expression in an
+-- updated "response" environment. The response must be at least as large as
+-- the computed usages.
+occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
+ -> (forall env'. OccEnv Occ env env'
+ -- response OccEnv must be at least as large as the OccEnv returned above
+ -> (forall env''. OccEnv () env env'' -> Ex env'' t')
+ -> r)
+ -> r
+occCountX initialS topexpr k = case topexpr of
+ EVar _ t i ->
+ withSome (onehotOccEnv i (Occ One One) s) $ \env ->
+ k env $ \env' ->
+ withSome (occEnvPrjS env' i) $ \(Pair s' i') ->
+ projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i')
+ ELet _ rhs body ->
+ occCountX s body $ \envB mkbody ->
+ occEnvPop' envB $ \envB' s1 ->
+ occCountX s1 rhs $ \envR mkrhs ->
+ withSome (Some envB' <> Some envR) $ \env ->
+ k env $ \env' ->
+ ELet ext (mkrhs env') (mkbody (OccPush env' () s1))
+ EPair _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsPair' s1 s2 ->
+ occCountX s1 a $ \env1 mka ->
+ occCountX s2 b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPair ext (mka env') (mkb env')
+ EFst _ e ->
+ occCountX (SsPair s SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EFst ext (mke env')
+ ESnd _ e ->
+ occCountX (SsPair SsNone s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ ESnd ext (mke env')
+ ENil _ ->
+ case s of
+ SsFull -> k OccEnd (\_ -> ENil ext)
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ EInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ EInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ ECase _ e a b ->
+ occCountX s a $ \env1' mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env1' $ \env1 s1 ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
+ ENothing _ t ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t))
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EJust _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsMaybe s' ->
+ occCountX s' e $ \env1 mke ->
+ k env1 $ \env' ->
+ EJust ext (mke env')
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EMaybe _ a b e ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsMaybe s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
+ ELNil _ t1 t2 ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2))
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELCase _ e a b c ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occCountX s c $ \env3' mkc ->
+ occEnvPop' env2' $ \env2 s1 ->
+ occEnvPop' env3' $ \env3 s2 ->
+ occCountX (SsLEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
+ k env $ \env' ->
+ ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
+
+ EConstArr _ n t x ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
+
+ EBuild _ n a b ->
+ case s of
+ SsNone ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsNone b $ \env2'' mkb ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
+ k env $ \env' ->
+ use (EBuild ext n (mka env') $
+ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
+ ENil ext) $
+ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX s' b $ \env2'' mkb ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
+ k env $ \env' ->
+ EBuild ext n (mka env') $
+ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
+ EFold1Inner _ commut a b c ->
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1' ->
+ let s1 = case s1' of
+ SsNone -> Some SsNone
+ SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
+ s0 = case s of
+ SsNone -> Some SsNone
+ SsArr' s' -> Some s' in
+ withSome (s1 <> s0) $ \sElt ->
+ occCountX sElt b $ \env2 mkb ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsArr sElt) s $
+ EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush env' () (SsPair sElt sElt)))
+ (mkb env') (mkc env')
+
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
+
+ EUnit _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX s' e $ \env mke ->
+ k env $ \env' ->
+ EUnit ext (mke env')
+
+ EReplicate1Inner _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX (SsArr s') b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReplicate1Inner ext (mka env') (mkb env')
+
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+
+ EReshape _ n esh e ->
+ case s of
+ SsNone ->
+ occCountX SsNone esh $ \env1 mkesh ->
+ occCountX SsNone e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkesh env') $ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull esh $ \env1 mkesh ->
+ occCountX (SsArr s') e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReshape ext n (mkesh env') (mke env')
+
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' SsNone s2) ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $
+ emap (EPair ext (ENil ext) (evar IZ)) (mkb env')
+ SsArr' (SsPair' s1 SsNone) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $
+ emap (EPair ext (evar IZ) (ENil ext)) (mka env')
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
+ EFold1InnerD1 _ cm e1 e2 e3 ->
+ case s of
+ -- If nothing is necessary, we can execute a fold and then proceed to ignore it
+ SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
+ -- If we don't need the stores, still a fold suffices
+ SsPair' sP SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
+ -- If for whatever reason the additional stores themselves are
+ -- unnecessary but the shape of the array is, then oblige
+ SsPair' sP (SsArr' SsNone) ->
+ let STArr sn _ = typeOf e3
+ foldex =
+ elet (mapExt (\_ -> ext) e3) $
+ EPair ext
+ (EShape ext (evar IZ))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
+ (mapExt (\_ -> ext) (weakenExpr WSink e2))
+ (evar IZ))
+ in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
+ k env1 $ \env' ->
+ eunPair (mkfoldex env') $ \_ eshape earr ->
+ EPair ext earr (EBuild ext sn eshape (ENil ext))
+ -- If at least some of the additional stores are required, we need to keep this a mapAccum
+ SsPair' _ (SsArr' sB) ->
+ -- TODO: propagate usage of primals
+ occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
+ occEnvPop' env1_1' $ \env1' _ ->
+ occCountX SsFull e2 $ \env2 mkb ->
+ occCountX SsFull e3 $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
+ EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
+ (mkb env') (mkc env')
+
+ EFold1InnerD2 _ cm ef ebog ed ->
+ -- TODO: propagate usage of duals
+ occCountX SsFull ef $ \env1_2' mkef ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' sB ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD2 ext cm
+ (mkef (OccPush (OccPush env' () sB) () SsFull))
+ (mkebog env') (mked env')
+
+ EConst _ t x ->
+ k OccEnd $ \_ ->
+ case s of
+ SsNone -> ENil ext
+ SsFull -> EConst ext t x
+
+ EIdx0 _ e ->
+ occCountX (SsArr s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EIdx0 ext (mke env')
+
+ EIdx1 _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX (SsArr s') a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx1 ext (mka env') (mkb env')
+
+ EIdx _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ _ ->
+ occCountX (SsArr s) a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx ext (mka env') (mkb env')
+
+ EShape _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX (SsArr SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EShape ext (mke env')
+
+ EOp _ op e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
+
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
+
+ ERecompute _ e ->
+ occCountX s e $ \env1 mke ->
+ k env1 $ \env' ->
+ ERecompute ext (mke env')
+
+ EWith _ t a b ->
+ case s of
+ SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with
+ occCountX SsNone b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ withSome (case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone) $ \s1' ->
+ occCountX s1' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $
+ ENil ext
+ SsPair sB sA ->
+ occCountX sB b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ let s1' = case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone in
+ withSome (Some sA <> s1') $ \sA' ->
+ occCountX sA' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $
+ EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA')))
+ SsFull -> occCountX (SsPair SsFull SsFull) topexpr k
+
+ EAccum _ t p a sp b e ->
+ -- TODO: do better!
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull e $ \env3 mke ->
+ withSome (Some env1 <> Some env2) $ \env12 ->
+ withSome (Some env12 <> Some env3) $ \env ->
+ k env $ \env' ->
+ case s of {SsFull -> id; SsNone -> id} $
+ EAccum ext t p (mka env') sp (mkb env') (mke env')
+
+ EZero _ t e ->
+ occCountX (subZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EZero ext (applySubstrucM s t) (mke env')
+ where
+ subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2)
+ subZeroInfo SsFull = SsFull
+ subZeroInfo SsNone = SsNone
+ subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2)
+ subZeroInfo SsEither{} = error "Either is not a monoid"
+ subZeroInfo SsLEither{} = SsNone
+ subZeroInfo SsMaybe{} = SsNone
+ subZeroInfo (SsArr s') = SsArr (subZeroInfo s')
+ subZeroInfo SsAccum{} = error "Accum is not a monoid"
+
+ EDeepZero _ t e ->
+ occCountX (subDeepZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EDeepZero ext (applySubstrucM s t) (mke env')
+ where
+ subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2)
+ subDeepZeroInfo SsFull = SsFull
+ subDeepZeroInfo SsNone = SsNone
+ subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo SsEither{} = error "Either is not a monoid"
+ subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s')
+ subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s')
+ subDeepZeroInfo SsAccum{} = error "Accum is not a monoid"
+
+ EPlus _ t a b ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPlus ext (applySubstrucM s t) (mka env') (mkb env')
+
+ EOneHot _ t p a b ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb -> -- TODO: do better
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env')
+
+ EError _ t msg ->
+ k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
+ where
+ s = simplifySubstruc (typeOf topexpr) initialS
+
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr' SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
+
+
+deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
+deleteUnused SNil (Some OccEnd) k = k SETop
+deleteUnused (_ `SCons` env) (Some OccEnd) k =
+ deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub)
+deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
+ deleteUnused env (Some occenv) $ \sub ->
+ case count of Zero -> k (SENo sub)
+ _ -> k (SEYesR sub)
+
+unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
+unsafeWeakenWithSubenv = \sub ->
+ subst (\x t i -> case sinkViaSubenv i sub of
+ Just i' -> EVar x t i'
+ Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
+ where
+ sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
+ sinkViaSubenv IZ (SEYesR _) = Just IZ
+ sinkViaSubenv IZ (SENo _) = Nothing
+ sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs
new file mode 100644
index 0000000..8e6b745
--- /dev/null
+++ b/src/CHAD/AST/Env.hs
@@ -0,0 +1,95 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.Env where
+
+import Data.Type.Equality
+
+import CHAD.AST.Sparse
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+-- | @env'@ is a subset of @env@: each element of @env@ is either included in
+-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
+data Subenv' s env env' where
+ SETop :: Subenv' s '[] '[]
+ SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
+ SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
+deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
+
+type Subenv = Subenv' (:~:)
+type SubenvS = Subenv' Sparse
+
+pattern SEYesR :: forall tenv tenv'. ()
+ => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
+ => Subenv env env' -> Subenv tenv tenv'
+pattern SEYesR s = SEYes Refl s
+
+{-# COMPLETE SETop, SEYesR, SENo #-}
+
+subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
+subList SNil SETop = SNil
+subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
+subList (SCons _ xs) (SENo sub) = subList xs sub
+
+subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env
+subenvAll SNil = SETop
+subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env)
+
+subenvNone :: SList f env -> Subenv' s env '[]
+subenvNone SNil = SETop
+subenvNone (SCons _ env) = SENo (subenvNone env)
+
+subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t']
+subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
+subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
+subenvOnehot SNil i _ = case i of {}
+
+subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
+subenvCompose SETop SETop = SETop
+subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
+
+subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
+subenvConcat sub1 SETop = sub1
+subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
+subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
+
+-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2
+-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r
+-- subenvSplit SNil sub k = k SETop sub
+-- subenvSplit (SCons _ list) (SENo sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SENo sub1) sub2
+-- subenvSplit (SCons _ list) (SEYes s sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SEYes s sub1) sub2
+
+sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
+sinkWithSubenv SETop = WId
+sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SENo sub) = sinkWithSubenv sub
+
+wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
+wUndoSubenv SETop = WId
+wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
+
+subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env'
+subenvMap _ SNil SETop = SETop
+subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub)
+subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub)
+
+subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env')
+subenvD2E SETop = SETop
+subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub)
+subenvD2E (SENo sub) = SENo (subenvD2E sub)
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
new file mode 100644
index 0000000..3f6a3af
--- /dev/null
+++ b/src/CHAD/AST/Pretty.hs
@@ -0,0 +1,525 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
+
+import Control.Monad (ap)
+import Data.List (intersperse, intercalate)
+import Data.Functor.Const
+import qualified Data.Functor.Product as Product
+import Data.String (fromString)
+import Prettyprinter
+import Prettyprinter.Render.String
+
+import qualified Data.Text.Lazy as TL
+import qualified Prettyprinter.Render.Terminal as PT
+import System.Console.ANSI (hSupportsANSI)
+import System.IO (stdout)
+import System.IO.Unsafe (unsafePerformIO)
+
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+class PrettyX x where
+ prettyX :: x t -> String
+
+ prettyXsuffix :: x t -> String
+ prettyXsuffix x = "<" ++ prettyX x ++ ">"
+
+instance PrettyX (Const ()) where
+ prettyX _ = ""
+ prettyXsuffix _ = ""
+
+
+type SVal = SList (Const String)
+
+newtype M a = M { runM :: Int -> (a, Int) }
+ deriving (Functor)
+instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
+instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) }
+
+genId :: M Int
+genId = M (\i -> (i, i + 1))
+
+nameBaseForType :: STy t -> String
+nameBaseForType STNil = "nil"
+nameBaseForType (STPair{}) = "p"
+nameBaseForType (STEither{}) = "e"
+nameBaseForType (STMaybe{}) = "m"
+nameBaseForType (STScal STI32) = "n"
+nameBaseForType (STScal STI64) = "n"
+nameBaseForType (STArr{}) = "a"
+nameBaseForType (STAccum{}) = "ac"
+nameBaseForType _ = "x"
+
+genName' :: String -> M String
+genName' prefix = (prefix ++) . show <$> genId
+
+genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn' prefix ty idx ex
+ | occCount idx ex == mempty = case ty of STNil -> return "()"
+ _ -> return "_"
+ | otherwise = genName' prefix
+
+-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
+genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
+
+pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
+pprintExpr = putStrLn . ppExpr knownEnv
+
+ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr senv e = render $ fst . flip runM 1 $ do
+ val <- mkVal senv
+ e' <- ppExpr' 0 val e
+ let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
+ return $ group $ flatAlt
+ (hang 2 $
+ ppString lam
+ <> hardline <> e')
+ (ppString lam <+> e')
+ where
+ mkVal :: SList f env -> M (SVal env)
+ mkVal SNil = return SNil
+ mkVal (SCons _ v) = do
+ val <- mkVal v
+ name <- genName' "arg"
+ return (Const name `SCons` val)
+
+ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExpr' d val expr = case expr of
+ EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr
+
+ e@ELet{} -> ppExprLet d val e
+
+ EPair _ a b -> do
+ a' <- ppExpr' 0 val a
+ b' <- ppExpr' 0 val b
+ return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr)
+ (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr)
+
+ EFst _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e'
+
+ ESnd _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e'
+
+ ENil _ -> return $ ppString "()"
+
+ EInl _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e'
+
+ EInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e'
+
+ ECase _ e a b -> do
+ e' <- ppExpr' 0 val e
+ let STEither t1 t2 = typeOf e
+ name1 <- genNameIfUsedIn t1 IZ a
+ a' <- ppExpr' 0 (Const name1 `SCons` val) a
+ name2 <- genNameIfUsedIn t2 IZ b
+ b' <- ppExpr' 0 (Const name2 `SCons` val) b
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a'
+ <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b'
+
+ ENothing _ _ -> return $ ppString "Nothing"
+
+ EJust _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e'
+
+ EMaybe _ a b e -> do
+ let STMaybe t = typeOf e
+ e' <- ppExpr' 0 val e
+ a' <- ppExpr' 0 val a
+ name <- genNameIfUsedIn t IZ b
+ b' <- ppExpr' 0 (Const name `SCons` val) b
+ return $ ppParen (d > 0) $
+ align $
+ group (flatAlt
+ (annotate AKey (ppString "case") <> ppX expr <+> e'
+ <> hardline <> annotate AKey (ppString "of"))
+ (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
+ <> hardline
+ <> indent 2
+ (ppString "Nothing" <+> ppString "->" <+> a'
+ <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
+
+ ELNil _ _ _ -> return (ppString "LNil")
+
+ ELInl _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
+
+ ELInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
+
+ ELCase _ e a b c -> do
+ e' <- ppExpr' 0 val e
+ let STLEither t1 t2 = typeOf e
+ a' <- ppExpr' 11 val a
+ name1 <- genNameIfUsedIn t1 IZ b
+ b' <- ppExpr' 0 (Const name1 `SCons` val) b
+ name2 <- genNameIfUsedIn t2 IZ c
+ c' <- ppExpr' 0 (Const name2 `SCons` val) c
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
+ <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
+ <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
+
+ EConstArr _ _ ty v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
+
+ EBuild _ n a b -> do
+ a' <- ppExpr' 11 val a
+ name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
+ e' <- ppExpr' 0 (Const name `SCons` val) b
+ let primName = ppString ("build" ++ intSubscript (fromSNat n))
+ return $ ppParen (d > 0) $
+ group $ flatAlt
+ (hang 2 $
+ annotate AHighlight primName <> ppX expr <+> a'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e')
+ (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
+
+ EMap _ a b -> do
+ let STArr _ t1 = typeOf b
+ name <- genNameIfUsedIn t1 IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 0) $
+ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
+
+ EFold1Inner _ cm a b c -> do
+ name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1i" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
+
+ ESum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e'
+
+ EUnit _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e'
+
+ EReplicate1Inner _ a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b']
+
+ EMaximum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e'
+
+ EMinimum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+
+ EReshape _ n esh e -> do
+ esh' <- ppExpr' 11 val esh
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
+
+ EZip _ e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
+
+ EFold1InnerD1 _ cm a b c -> do
+ name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1iD1" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ namef1 <- genNameIfUsedIn tB (IS IZ) ef
+ namef2 <- genNameIfUsedIn t2 IZ ef
+ ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
+ ebog' <- ppExpr' 11 val ebog
+ ed' <- ppExpr' 11 val ed
+ let opname = "fold1iD2" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr)
+ [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
+
+ EConst _ ty v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
+
+ EIdx0 _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e'
+
+ EIdx1 _ a b -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 9 val b
+ return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b'
+
+ EIdx _ a b -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 10 val b
+ return $ ppParen (d > 8) $
+ a' <+> ppString "!" <> ppX expr <+> b'
+
+ EShape _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e'
+
+ EOp _ op (EPair _ a b)
+ | (Infix, ops) <- operator op -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 9 val b
+ return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b'
+
+ EOp _ op e -> do
+ e' <- ppExpr' 11 val e
+ let ops = case operator op of
+ (Infix, s) -> "(" ++ s ++ ")"
+ (Prefix, s) -> s
+ return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e'
+
+ ECustom _ t1 t2 t3 a b c e1 e2 -> do
+ en1 <- genNameIfUsedIn t1 (IS IZ) a
+ en2 <- genNameIfUsedIn t2 IZ a
+ pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
+ pn2 <- genNameIfUsedIn (d1 t2) IZ b
+ dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
+ dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
+ a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
+ b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
+ c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $
+ ppApp (ppString "custom" <> ppX expr)
+ [ppLam [ppString en1, ppString en2] a'
+ ,ppLam [ppString pn1, ppString pn2] b'
+ ,ppLam [ppString dn1, ppString dn2] c'
+ ,e1'
+ ,e2']
+
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
+ EWith _ t e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
+ e2' <- ppExpr' 0 (Const name `SCons` val) e2
+ return $ ppParen (d > 0) $
+ group $ flatAlt
+ (hang 2 $
+ annotate AWith (ppString "with") <> ppX expr <+> e1'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e2')
+ (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
+
+ EAccum _ t prj e1 sp e2 e3 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ e3' <- ppExpr' 11 val e3
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
+
+ EZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
+ EPlus _ t a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b']
+
+ EOneHot _ t prj a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b']
+
+ EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
+
+ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExprLet d val etop = do
+ let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
+ collect val' (ELet _ rhs body) = do
+ let occ = occCount IZ body
+ name <- genNameIfUsedIn (typeOf rhs) IZ body
+ rhs' <- ppExpr' 0 val' rhs
+ (binds, core) <- collect (Const name `SCons` val') body
+ return ((name, occ, rhs') : binds, core)
+ collect val' e = ([],) <$> ppExpr' 0 val' e
+
+ (binds, core) <- collect val etop
+
+ return $ ppParen (d > 0) $
+ align $
+ annotate AKey (ppString "let")
+ <+> align (mconcat $ intersperse hardline $
+ map (\(name, _occ, rhs) ->
+ ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs)
+ binds)
+ <> hardline <> annotate AKey (ppString "in") <+> core
+
+ppApp :: ADoc -> [ADoc] -> ADoc
+ppApp fun args = group $ fun <+> align (sep args)
+
+ppLam :: [ADoc] -> ADoc -> ADoc
+ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
+ <> softline <> body <> ppString ")")
+
+ppAcPrj :: SMTy a -> SAcPrj p a b -> String
+ppAcPrj _ SAPHere = "."
+ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
+ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
+ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
+ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
+
+ppCommut :: Commutative -> String
+ppCommut Commut = "(C)"
+ppCommut Noncommut = ""
+
+ppX :: PrettyX x => Expr x env t -> ADoc
+ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
+
+data Fixity = Prefix | Infix
+ deriving (Show)
+
+operator :: SOp a t -> (Fixity, String)
+operator OAdd{} = (Infix, "+")
+operator OMul{} = (Infix, "*")
+operator ONeg{} = (Prefix, "negate")
+operator OLt{} = (Infix, "<")
+operator OLe{} = (Infix, "<=")
+operator OEq{} = (Infix, "==")
+operator ONot = (Prefix, "not")
+operator OAnd = (Infix, "&&")
+operator OOr = (Infix, "||")
+operator OIf = (Prefix, "ifB")
+operator ORound64 = (Prefix, "round")
+operator OToFl64 = (Prefix, "toFl64")
+operator ORecip{} = (Prefix, "recip")
+operator OExp{} = (Prefix, "exp")
+operator OLog{} = (Prefix, "log")
+operator OIDiv{} = (Infix, "`div`")
+operator OMod{} = (Infix, "`mod`")
+
+ppSTy :: Int -> STy t -> String
+ppSTy d ty = render $ ppSTy' d ty
+
+ppSTy' :: Int -> STy t -> Doc q
+ppSTy' _ STNil = ppString "1"
+ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
+ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
+ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
+ppSTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
+ppSTy' _ (STScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+ STBool -> "bool"
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+
+ppSMTy :: Int -> SMTy t -> String
+ppSMTy d ty = render $ ppSMTy' d ty
+
+ppSMTy' :: Int -> SMTy t -> Doc q
+ppSMTy' _ SMTNil = ppString "1"
+ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
+ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
+ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
+ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
+ppSMTy' _ (SMTScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+
+ppString :: String -> Doc x
+ppString = fromString
+
+ppParen :: Bool -> Doc x -> Doc x
+ppParen True = parens
+ppParen False = id
+
+intSubscript :: Int -> String
+intSubscript = \case 0 -> "₀"
+ n | n < 0 -> '₋' : go (-n) ""
+ | otherwise -> go n ""
+ where go 0 suff = suff
+ go n suff = let (q, r) = n `quotRem` 10
+ in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff)
+
+data Annot = AKey | AWith | AHighlight | AMonoid | AExt
+ deriving (Show)
+
+annotToANSI :: Annot -> PT.AnsiStyle
+annotToANSI AKey = PT.bold
+annotToANSI AWith = PT.color PT.Red <> PT.underlined
+annotToANSI AHighlight = PT.color PT.Blue
+annotToANSI AMonoid = PT.color PT.Green
+annotToANSI AExt = PT.colorDull PT.White
+
+type ADoc = Doc Annot
+
+render :: Doc Annot -> String
+render =
+ (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI
+ else renderString)
+ . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 }
+ where
+ stdoutTTY = unsafePerformIO $ hSupportsANSI stdout
diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs
new file mode 100644
index 0000000..9156160
--- /dev/null
+++ b/src/CHAD/AST/Sparse.hs
@@ -0,0 +1,287 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
+module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where
+
+import Data.Type.Equality
+
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.Data (SBool(..))
+
+
+sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
+sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
+sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
+sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
+sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
+ eunPair e1 $ \w1 e1a e1b ->
+ eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
+ EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
+ (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
+sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
+ elet e2 $
+ elcase (weakenExpr WSink e1)
+ (evar IZ)
+ (elcase (evar (IS IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
+ (elcase (evar (IS IZ))
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
+ elet e2 $
+ emaybe (weakenExpr WSink e1)
+ (evar IZ)
+ (emaybe (evar (IS IZ))
+ (EJust ext (evar IZ))
+ (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
+sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+
+
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
+data Injection sp a b where
+ -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
+ -- 'sparsePlusS' can provide injections even if the caller doesn't require
+ -- them. This simplifies the sparsePlusS code.
+ Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
+ Noinj :: Injection False a b
+
+withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
+withInj (Inj f) k = Inj (k f)
+withInj Noinj _ = Noinj
+
+withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
+ -> ((forall e. Ex e a1 -> Ex e b1)
+ -> (forall e. Ex e a2 -> Ex e b2)
+ -> (forall e'. Ex e' a' -> Ex e' b'))
+ -> Injection sp a' b'
+withInj2 (Inj f) (Inj g) k = Inj (k f g)
+withInj2 Noinj _ _ = Noinj
+withInj2 _ Noinj _ = Noinj
+
+-- | This function produces quadratically-sized code in the presence of nested
+-- dynamic sparsity. TODO can this be improved?
+sparsePlusS
+ :: SBool inj1 -> SBool inj2
+ -> SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
+ -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
+
+-- simplifications
+sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
+ sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
+sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
+ sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
+
+sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
+ let ta = applySparse sp1 (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
+ let tb = applySparse sp2 (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
+ let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
+ let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
+ let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
+ let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
+sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
+
+-- TODO: sparse of Just is just Maybe
+
+-- dense plus
+sparsePlusS _ _ t sp1 sp2 k
+ | Just Refl <- isDense t sp1
+ , Just Refl <- isDense t sp2
+ = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+
+-- handle absents
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+
+-- double sparse yields sparse
+sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- single sparse can yield non-sparse if the other argument is always present
+sparsePlusS SF _ t (SpSparse sp1) sp2 k =
+ sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
+ k sp3 Noinj (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
+ sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
+ k sp3 inj2 inj1 (flip plus)
+
+-- products
+sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
+ sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
+ sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
+ k (SpPair sp3a sp3b)
+ (withInj2 minj13a minj13b $ \inj13a inj13b ->
+ \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
+ (withInj2 minj23a minj23b $ \inj23a inj23b ->
+ \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
+ (\x1 x2 ->
+ eunPair x1 $ \w1 x1a x1b ->
+ eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
+ EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
+
+-- coproducts
+sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
+ sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
+ let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
+ inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
+ inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
+ in
+ k (SpLEither sp3a sp3b)
+ (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
+ (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x2 $
+ elcase (weakenExpr WSink x1)
+ (elcase (evar IZ)
+ nil
+ (inl (inj23a (evar IZ)))
+ (inr (inj23b (evar IZ))))
+ (elcase (evar (IS IZ))
+ (inl (inj13a (evar IZ)))
+ (inl (plusa (evar (IS IZ)) (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
+ (elcase (evar (IS IZ))
+ (inr (inj13b (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
+ (inr (plusb (evar (IS IZ)) (evar IZ)))))
+
+-- maybe
+sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- dense array cotangents simply recurse
+sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
+ k (SpArr sp3)
+ (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
+ (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+ (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
+ (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+
+-- scalars
+sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
new file mode 100644
index 0000000..8f41ba4
--- /dev/null
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.Sparse.Types where
+
+import Data.Kind (Type, Constraint)
+import Data.Type.Equality
+
+import CHAD.AST.Types
+
+
+data Sparse t t' where
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
+deriving instance Show (Sparse t t')
+
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
+
+
+class IsSubType s where
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull _ = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ SMTy
+ subtApply = applySparse
+
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs
new file mode 100644
index 0000000..34267e4
--- /dev/null
+++ b/src/CHAD/AST/SplitLets.hs
@@ -0,0 +1,191 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST.SplitLets (splitLets) where
+
+import Data.Type.Equality
+
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.Lemmas
+
+
+splitLets :: Ex env t -> Ex env t
+splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
+
+splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
+splitLets' = \sub -> \case
+ EVar _ t i -> sub t i WId
+ ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ECase x e a b ->
+ let STEither t1 t2 = typeOf e
+ in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
+ EMaybe x a b e ->
+ let STMaybe t1 = typeOf e
+ in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
+ ELCase x e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
+ EFold1Inner x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD1 x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD2 x cm a b c ->
+ let STArr _ tB = typeOf b
+ STArr _ t2 = typeOf c
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
+
+ EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
+ EFst x e -> EFst x (splitLets' sub e)
+ ESnd x e -> ESnd x (splitLets' sub e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (splitLets' sub e)
+ EInr x t e -> EInr x t (splitLets' sub e)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (splitLets' sub e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (splitLets' sub e)
+ ELInr x t e -> ELInr x t (splitLets' sub e)
+ EConstArr x n t a -> EConstArr x n t a
+ EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b)
+ ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
+ EUnit x e -> EUnit x (splitLets' sub e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
+ EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
+ EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
+ EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b)
+ EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (splitLets' sub e)
+ EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
+ EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es)
+ EShape x e -> EShape x (splitLets' sub e)
+ EOp x op e -> EOp x op (splitLets' sub e)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
+ EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
+ EZero x t ezi -> EZero x t (splitLets' sub ezi)
+ EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
+ EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
+ EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
+ EError x t s -> EError x t s
+ where
+ sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
+ sinkF _ t IZ w = EVar ext t (w @> IZ)
+ sinkF f t (IS i) w = f t i (w .> WSink)
+
+ split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
+ split1 sub (tbind :: STy bind) body =
+ let (ptrs, bs) = split tbind
+ in letBinds bs $
+ splitLets' (\cases _ IZ w -> subPointers ptrs w
+ t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w)))
+ body
+
+ split2 :: forall bind1 bind2 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
+ split2 sub tbind1 tbind2 body =
+ let (ptrs1', bs1') = split @env' tbind1
+ bs1 = fst (weakenBindingsE WSink bs1')
+ (ptrs2, bs2) = split @(bind1 : env') tbind2
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
+ _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env')))
+ t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
+ body
+
+ -- TODO: abstract this to splitN lol wtf
+ _split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ _split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ let (ptrs1, bs1') = split @env' tbind1
+ (ptrs2, bs2') = split @(bind1 : env') tbind2
+ (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
+ (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
+ bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
+ bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
+ bs3 = fst (weakenBindingsE WSink bs3')
+ b1 = bindingsBinds bs1
+ b2 = bindingsBinds bs2
+ b3 = bindingsBinds bs3
+ b4 = bindingsBinds bs4
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
+ _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
+ _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
+ _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
+ t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
+ body
+
+type family Split t where
+ Split (TPair a b) = SplitRec (TPair a b)
+ Split _ = '[]
+
+type family SplitRec t where
+ SplitRec TNil = '[]
+ SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a)
+ SplitRec t = '[t]
+
+data Pointers env t where
+ Point :: STy t -> Idx env t -> Pointers env t
+ PNil :: Pointers env TNil
+ PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
+ PWeak :: env' :> env -> Pointers env' t -> Pointers env t
+
+subPointers :: Pointers env t -> env :> env' -> Ex env' t
+subPointers (Point t i) w = EVar ext t (w @> i)
+subPointers PNil _ = ENil ext
+subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w)
+subPointers (PWeak w' p) w = subPointers p (w .> w')
+
+split :: forall env t. STy t
+ -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t))
+split typ = case typ of
+ STPair{} -> splitRec (EVar ext typ IZ) typ
+ STNil -> other
+ STEither{} -> other
+ STLEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ where
+ other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
+ other = (Point typ IZ, BTop)
+
+splitRec :: forall env t. Ex env t -> STy t
+ -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t))
+splitRec rhs typ = case typ of
+ STNil -> (PNil, BTop)
+ STPair (a :: STy a) (b :: STy b)
+ | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->
+ let (p1, bs1) = splitRec (EFst ext rhs) a
+ (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
+ in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
+ STEither{} -> other
+ STLEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ where
+ other :: (Pointers (t : env) t, Bindings Ex env '[t])
+ other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
new file mode 100644
index 0000000..059077d
--- /dev/null
+++ b/src/CHAD/AST/Types.hs
@@ -0,0 +1,215 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.Types where
+
+import Data.Int (Int32, Int64)
+import Data.GADT.Compare
+import Data.GADT.Show
+import Data.Kind (Type)
+import Data.Type.Equality
+
+import CHAD.Data
+
+
+type data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TLEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty -- ^ contained type must be a monoid type
+
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
+ STMaybe :: STy a -> STy (TMaybe a)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STAccum :: SMTy t -> STy (TAccum t)
+deriving instance Show (STy t)
+
+instance GCompare STy where
+ gcompare = \cases
+ STNil STNil -> GEQ
+ STNil _ -> GLT ; _ STNil -> GGT
+ (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STPair{} _ -> GLT ; _ STPair{} -> GGT
+ (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
+ STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
+ (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ STArr{} _ -> GLT ; _ STArr{} -> GGT
+ (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
+ STScal{} _ -> GLT ; _ STScal{} -> GGT
+ (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+
+instance TestEquality STy where testEquality = geq
+instance GEq STy where geq = defaultGeq
+instance GShow STy where gshowsPrec = defaultGshowsPrec
+
+-- | Monoid types
+type SMTy :: Ty -> Type
+data SMTy t where
+ SMTNil :: SMTy TNil
+ SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
+ SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
+ SMTMaybe :: SMTy a -> SMTy (TMaybe a)
+ SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
+ SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+deriving instance Show (SMTy t)
+
+instance GCompare SMTy where
+ gcompare = \cases
+ SMTNil SMTNil -> GEQ
+ SMTNil _ -> GLT ; _ SMTNil -> GGT
+ (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT
+ (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT
+ (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
+ SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT
+ (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
+ (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
+ -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+
+instance TestEquality SMTy where testEquality = geq
+instance GEq SMTy where geq = defaultGeq
+instance GShow SMTy where gshowsPrec = defaultGshowsPrec
+
+fromSMTy :: SMTy t -> STy t
+fromSMTy = \case
+ SMTNil -> STNil
+ SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
+ SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
+ SMTMaybe t -> STMaybe (fromSMTy t)
+ SMTArr n t -> STArr n (fromSMTy t)
+ SMTScal sty -> STScal sty
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+instance GCompare SScalTy where
+ gcompare = \cases
+ STI32 STI32 -> GEQ
+ STI32 _ -> GLT ; _ STI32 -> GGT
+ STI64 STI64 -> GEQ
+ STI64 _ -> GLT ; _ STI64 -> GGT
+ STF32 STF32 -> GEQ
+ STF32 _ -> GLT ; _ STF32 -> GGT
+ STF64 STF64 -> GEQ
+ STF64 _ -> GLT ; _ STF64 -> GGT
+ STBool STBool -> GEQ
+ -- STBool _ -> GLT ; _ STBool -> GGT
+
+instance TestEquality SScalTy where testEquality = geq
+instance GEq SScalTy where geq = defaultGeq
+instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
+type TIx = TScal TI64
+
+tIx :: STy TIx
+tIx = STScal STI64
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False
+
+type family ScalIsFloating t where
+ ScalIsFloating TI32 = False
+ ScalIsFloating TI64 = False
+ ScalIsFloating TF32 = True
+ ScalIsFloating TF64 = True
+ ScalIsFloating TBool = False
+
+type family ScalIsIntegral t where
+ ScalIsIntegral TI32 = True
+ ScalIsIntegral TI64 = True
+ ScalIsIntegral TF32 = False
+ ScalIsIntegral TF64 = False
+ ScalIsIntegral TBool = False
+
+-- | Returns true for arrays /and/ accumulators.
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
+
+type family Tup env where
+ Tup '[] = TNil
+ Tup (t : ts) = TPair (Tup ts) t
+
+mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
+ -> SList f list -> f (Tup list)
+mkTup nil _ SNil = nil
+mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
+
+tTup :: SList STy env -> STy (Tup env)
+tTup = mkTup STNil STPair
+
+unTup :: (forall a b. c (TPair a b) -> (c a, c b))
+ -> SList f list -> c (Tup list) -> SList c list
+unTup _ SNil _ = SNil
+unTup unpack (_ `SCons` list) tup =
+ let (xs, x) = unpack tup
+ in x `SCons` unTup unpack list xs
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
new file mode 100644
index 0000000..27c5f0a
--- /dev/null
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -0,0 +1,255 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
+
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+
+
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
+unMonoid :: Ex env t -> Ex env t
+unMonoid = \case
+ EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
+ EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
+ EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
+
+ EVar _ t i -> EVar ext t i
+ ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
+ EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
+ EFst _ e -> EFst ext (unMonoid e)
+ ESnd _ e -> ESnd ext (unMonoid e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext t (unMonoid e)
+ EInr _ t e -> EInr ext t (unMonoid e)
+ ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
+ ENothing _ t -> ENothing ext t
+ EJust _ e -> EJust ext (unMonoid e)
+ EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unMonoid e)
+ ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
+ EConstArr _ n t x -> EConstArr ext n t x
+ EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
+ EUnit _ e -> EUnit ext (unMonoid e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (unMonoid e)
+ EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
+ EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
+ EShape _ e -> EShape ext (unMonoid e)
+ EOp _ op e -> EOp ext op (unMonoid e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
+ EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
+ EError _ t s -> EError ext t s
+
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+-- don't destroy the effects!
+zero SMTNil e = ELet ext e $ ENil ext
+zero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
+zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
+zero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil e = elet e $ ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
+plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
+-- don't destroy the effects!
+plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
+plus (SMTPair t1 t2) a b =
+ let t = STPair (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
+ (EFst ext (EVar ext t IZ)))
+ (plus t2 (ESnd ext (EVar ext t (IS IZ)))
+ (ESnd ext (EVar ext t IZ)))
+plus (SMTLEither t1 t2) a b =
+ let t = STLEither (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t IZ)
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
+ (EError ext t "plus l+r"))
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (EError ext t "plus r+l")
+ (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
+plus (SMTMaybe t) a b =
+ ELet ext b $
+ EMaybe ext
+ (EVar ext (STMaybe (fromSMTy t)) IZ)
+ (EJust ext
+ (EMaybe ext
+ (EVar ext (fromSMTy t) IZ)
+ (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
+ (weakenExpr WSink a)
+plus (SMTArr _ t) a b =
+ ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ a b
+plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
+
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
+onehot typ topprj idx arg = case (typ, topprj) of
+ (_, SAPHere) ->
+ ELet ext arg $
+ EVar ext (fromSMTy typ) IZ
+
+ (SMTPair t1 t2, SAPFst prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t1 in
+ EPair ext (EVar ext toh IZ)
+ (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
+
+ (SMTPair t1 t2, SAPSnd prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t2 in
+ EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
+ (EVar ext toh IZ)
+
+ (SMTLEither t1 t2, SAPLeft prj) ->
+ ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ (SMTLEither t1 t2, SAPRight prj) ->
+ ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
+
+ (SMTMaybe t1, SAPJust prj) ->
+ EJust ext (onehot t1 prj idx arg)
+
+ (SMTArr n t1, SAPArrIdx prj) ->
+ let tidx = tTup (sreplicate n tIx)
+ in ELet ext idx $
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
+ zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/CHAD/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs
new file mode 100644
index 0000000..ac0d152
--- /dev/null
+++ b/src/CHAD/AST/Weaken.hs
@@ -0,0 +1,138 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+
+-- The reason why this is a separate module with "little" in it:
+{-# LANGUAGE AllowAmbiguousTypes #-}
+
+module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where
+
+import Data.Bifunctor (first)
+import Data.Functor.Const
+import Data.GADT.Compare
+import Data.Kind (Type)
+
+import CHAD.Data
+import CHAD.Lemmas
+
+
+type Idx :: [k] -> k -> Type
+data Idx env t where
+ IZ :: Idx (t : env) t
+ IS :: Idx env t -> Idx (a : env) t
+deriving instance Show (Idx env t)
+
+instance GEq (Idx env) where
+ geq IZ IZ = Just Refl
+ geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl
+ geq _ _ = Nothing
+
+splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
+splitIdx SNil i = Right i
+splitIdx (SCons _ _) IZ = Left IZ
+splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
+
+slistIdx :: SList f list -> Idx list t -> f t
+slistIdx (SCons x _) IZ = x
+slistIdx (SCons _ list) (IS i) = slistIdx list i
+slistIdx SNil i = case i of {}
+
+idx2int :: Idx env t -> Int
+idx2int IZ = 0
+idx2int (IS n) = 1 + idx2int n
+
+data env :> env' where
+ WId :: env :> env
+ WSink :: forall t env. env :> (t : env)
+ WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env')
+ WPop :: (t : env) :> env' -> env :> env'
+ WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
+ WClosed :: '[] :> env
+ WIdx :: Idx env t -> (t : env) :> env
+ WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env'
+ -> Append pre (t : env) :> t : Append pre env'
+ WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs
+ -> Append as (Append bs env) :> Append bs (Append as env)
+ WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs
+ -> as :> bs -> env1 :> env2
+ -> Append as env1 :> Append bs env2
+deriving instance Show (env :> env')
+infix 4 :>
+
+infixr 2 @>
+(@>) :: env :> env' -> Idx env t -> Idx env' t
+WId @> i = i
+WSink @> i = IS i
+WCopy _ @> IZ = IZ
+WCopy w @> IS i = IS (w @> i)
+WPop w @> i = w @> IS i
+WThen w1 w2 @> i = w2 @> w1 @> i
+WClosed @> i = case i of {}
+WIdx j @> IZ = j
+WIdx _ @> IS i = i
+WPick SNil w @> i = WCopy w @> i
+WPick (_ `SCons` _) _ @> IZ = IS IZ
+WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i
+WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i =
+ case splitIdx @(Append bs env) as i of
+ Left i' -> indexSinks bs (indexRaiseAbove @env as i')
+ Right i' -> case splitIdx @env bs i' of
+ Left j -> indexRaiseAbove @(Append as env) bs j
+ Right j -> indexSinks bs (indexSinks as j)
+WStack @env1 @env2 as bs wlo whi @> i =
+ case splitIdx @env1 as i of
+ Left i' -> indexRaiseAbove @env2 bs (wlo @> i')
+ Right i' -> indexSinks bs (whi @> i')
+
+indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t
+indexSinks SNil j = j
+indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j)
+
+indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t
+indexRaiseAbove = flip go
+ where
+ go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t
+ go IZ (_ `SCons` _) = IZ
+ go (IS i) (_ `SCons` as) = IS (go i as)
+
+infixr 3 .>
+(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
+(.>) = flip WThen
+
+class KnownListSpine list where knownListSpine :: SList (Const ()) list
+instance KnownListSpine '[] where knownListSpine = SNil
+instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine
+
+wSinks' :: forall list env. KnownListSpine list => env :> Append list env
+wSinks' = wSinks (knownListSpine :: SList (Const ()) list)
+
+wSinks :: forall env bs f. SList f bs -> env :> Append bs env
+wSinks SNil = WId
+wSinks (SCons _ spine) = WSink .> wSinks spine
+
+wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env'
+wSinksAnd SNil w = w
+wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w
+
+wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
+wCopies bs w =
+ let bs' = slistMap (\_ -> Const ()) bs
+ in WStack bs' bs' WId w
+
+wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
+wRaiseAbove SNil _ = WClosed
+wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
+
+wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2
+wPops SNil w = w
+wPops (_ `SCons` bs) w = wPops bs (WPop w)
diff --git a/src/CHAD/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs
new file mode 100644
index 0000000..14d8c59
--- /dev/null
+++ b/src/CHAD/AST/Weaken/Auto.hs
@@ -0,0 +1,192 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE FunctionalDependencies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+{-# LANGUAGE AllowAmbiguousTypes #-}
+
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
+module CHAD.AST.Weaken.Auto (
+ autoWeak,
+ (&.), auto, auto1,
+ Layout(..),
+) where
+
+import Data.Functor.Const
+import Data.Kind (Constraint)
+import GHC.OverloadedLabels
+import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
+
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Lemmas
+
+
+type family Lookup name list where
+ Lookup name ('(name, x) : _) = x
+ Lookup name (_ : list) = Lookup name list
+ Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.")
+
+
+-- | The @withPre@ type parameter indicates whether there can be 'LPreW'
+-- occurrences within this layout. 'names' is the list of names that this
+-- layout /produces/. That is: for LPreW, it contains the target name. The
+-- 'names' list of a source layout must be a subset of the names list of the
+-- target layout (which cannot contain LPreW); this is checked with SubLayout.
+data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where
+ LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments)
+ -- | Pre-weaken with a weakening
+ LPreW :: forall name1 name2 segments.
+ SegmentName name1 -> SegmentName name2
+ -> Lookup name1 segments :> Lookup name2 segments
+ -> Layout True segments '[name2] (Lookup name1 segments)
+ (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2)
+infixr :++:
+
+instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where
+ fromLabel = LSeg (symbolSing @name)
+
+newtype SegmentName name = SegmentName (SSymbol name)
+ deriving (Show)
+
+instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where
+ fromLabel = SegmentName symbolSing
+
+
+type family SubLayout names1 names2 where
+ SubLayout '[] _ = () :: Constraint
+ SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2
+type family SubLayout' n ok names1 names2 where
+ SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.")
+ SubLayout' _ True names1 names2 = SubLayout names1 names2
+type family Contains n names where
+ Contains _ '[] = False
+ Contains n (n : _) = True
+ Contains n (_ : names) = Contains n names
+
+
+data SSegments (segments :: [(Symbol, [t])]) where
+ SSegNil :: SSegments '[]
+ SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
+
+instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where
+ fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
+
+auto :: KnownListSpine list => SList (Const ()) list
+auto = knownListSpine
+
+auto1 :: SList (Const ()) '[t]
+auto1 = Const () `SCons` SNil
+
+infixr &.
+(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2)
+(&.) = ssegmentsAppend
+ where
+ ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
+ ssegmentsAppend SSegNil l2 = l2
+ ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2)
+
+
+-- | If the found segment is a TopSeg, returns Nothing.
+segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
+segmentLookup = \segs name -> case go segs name of
+ Just ts -> ts
+ Nothing -> error $ "Segment not found: " ++ fromSSymbol name
+ where
+ go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
+ go SSegNil _ = Nothing
+ go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
+ case sameSymbol n name of
+ Just Refl ->
+ case go sseg name of
+ Nothing -> Just ts
+ Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
+ Nothing ->
+ case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
+ Refl -> go sseg name
+
+data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where
+ LinEnd :: LinLayout withPre segments '[]
+ LinApp :: SSymbol name -> LinLayout withPre segments env
+ -> LinLayout withPre segments (Append (Lookup name segments) env)
+ LinAppPreW :: SSymbol name1 -> SSymbol name2
+ -> Lookup name1 segments :> Lookup name2 segments
+ -> LinLayout True segments env
+ -> LinLayout True segments (Append (Lookup name1 segments) env)
+
+linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2)
+linLayoutAppend LinEnd lin = lin
+linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2)
+ | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
+ = LinApp name (linLayoutAppend lin1 lin2)
+linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2)
+ | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2
+ = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2)
+
+lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env
+lineariseLayout (LSeg name :: Layout _ _ _ seg)
+ | Refl <- lemAppendNil @seg
+ = LinApp name LinEnd
+lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
+lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg)
+ | Refl <- lemAppendNil @seg
+ = LinAppPreW name1 name2 w LinEnd
+
+preWeaken :: SSegments segments -> LinLayout True segments env
+ -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r
+preWeaken _ LinEnd k = k WId LinEnd
+preWeaken segs (LinApp name lin) k =
+ preWeaken segs lin $ \w lin' ->
+ k (wCopies (segmentLookup segs name) w) (LinApp name lin')
+preWeaken segs (LinAppPreW name1 name2 weak lin) k =
+ preWeaken segs lin $ \w lin' ->
+ k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin')
+
+pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env
+ -> r -- Name was not found in source
+ -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r)
+ -> r
+pullDown segs name@SSymbol linlayout kNotFound k =
+ case linlayout of
+ LinEnd -> kNotFound
+ LinApp n'@SSymbol lin
+ | Just Refl <- sameSymbol name n' -> k lin WId
+ | otherwise ->
+ pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w ->
+ k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name)
+ .> wCopies (segmentLookup segs n') w)
+
+sortLinLayouts :: SSegments segments
+ -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2
+sortLinLayouts _ LinEnd LinEnd = WId
+sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
+ | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
+ | otherwise =
+ pullDown segs name2 lin1
+ (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
+ (\tail1' w ->
+ -- We've pulled down name2 in lin1 so that it's at the head; the
+ -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
+ -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
+ -- wCopies the name2 on top of that.
+ wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
+sortLinLayouts _ LinEnd LinApp{} = WClosed
+sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"
+
+autoWeak :: SubLayout names1 names2
+ => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2
+autoWeak segs ly1 ly2 =
+ preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 ->
+ sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak
diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs
new file mode 100644
index 0000000..212cc7d
--- /dev/null
+++ b/src/CHAD/Analysis/Identity.hs
@@ -0,0 +1,436 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+module CHAD.Analysis.Identity (
+ identityAnalysis,
+ identityAnalysis',
+ ValId(..),
+ validSplitEither,
+) where
+
+import Data.Foldable (toList)
+import Data.List (intercalate)
+
+import CHAD.AST
+import CHAD.AST.Pretty (PrettyX(..))
+import CHAD.Data
+import CHAD.Drev.Types (d1, d2)
+import CHAD.Util.IdGen
+
+
+-- | Every array, scalar and accumulator has an ID. Trivial values such as
+-- Nothing only have the knowledge that they are indeed Nothing. Compound
+-- values know which values they consist of.
+data ValId t where
+ VINil :: ValId TNil
+ VIPair :: ValId a -> ValId b -> ValId (TPair a b)
+ VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
+ VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
+ VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
+ VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
+ VIArr :: Int -> Vec n Int -> ValId (TArr n t)
+ VIScal :: Int -> ValId (TScal t)
+ VIAccum :: Int -> ValId (TAccum t)
+deriving instance Show (ValId t)
+
+instance PrettyX ValId where
+ prettyX = \case
+ VINil -> ""
+ VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")"
+ VIEither (Left a) -> "(L" ++ prettyX a ++ ")"
+ VIEither (Right a) -> "(R" ++ prettyX a ++ ")"
+ VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")"
+ VIMaybe Nothing -> "N"
+ VIMaybe (Just a) -> 'J' : prettyX a
+ VIMaybe' a -> 'M' : prettyX a
+ VILEither (VIMaybe Nothing) -> "lN"
+ VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")"
+ VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"
+ VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
+ VIScal i -> show i
+ VIAccum i -> 'C' : show i
+
+validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b))
+validSplitEither (VIEither (Left v)) = (Just v, Nothing)
+validSplitEither (VIEither (Right v)) = (Nothing, Just v)
+validSplitEither (VIEither' v1 v2) = (Just v1, Just v2)
+
+-- | Symbolic partial evaluation.
+identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
+identityAnalysis env term = runIdGen 0 $ do
+ env' <- slistMapA genIds env
+ snd <$> idana env' term
+
+identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' env term = snd (runIdGen 0 (idana env term))
+
+idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
+idana env expr = case expr of
+ EVar _ t i -> do
+ let v = slistIdx env i
+ pure (v, EVar v t i)
+
+ ELet _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (v2, e2') <- idana (v1 `SCons` env) e2
+ pure (v2, ELet v2 e1' e2')
+
+ EPair _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (v2, e2') <- idana env e2
+ pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2')
+
+ EFst _ e -> do
+ (v, e') <- idana env e
+ let VIPair v1 _ = v
+ pure (v1, EFst v1 e')
+
+ ESnd _ e -> do
+ (v, e') <- idana env e
+ let VIPair _ v2 = v
+ pure (v2, ESnd v2 e')
+
+ ENil _ -> pure (VINil, ENil VINil)
+
+ EInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VIEither (Left v1)
+ pure (v, EInl v t2 e1')
+
+ EInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VIEither (Right v2)
+ pure (v, EInr v t1 e2')
+
+ ECase _ e1 e2 e3 -> do
+ let STEither t1 t2 = typeOf e1
+ (v1, e1') <- idana env e1
+ case v1 of
+ VIEither (Left v1') -> do
+ (v2, e2') <- idana (v1' `SCons` env) e2
+ scrap <- genIds t2
+ (_, e3') <- idana (scrap `SCons` env) e3
+ pure (v2, ECase v2 e1' e2' e3')
+ VIEither (Right v1') -> do
+ scrap <- genIds t1
+ (_, e2') <- idana (scrap `SCons` env) e2
+ (v3, e3') <- idana (v1' `SCons` env) e3
+ pure (v3, ECase v3 e1' e2' e3')
+ VIEither' v1'l v1'r -> do
+ (v2, e2') <- idana (v1'l `SCons` env) e2
+ (v3, e3') <- idana (v1'r `SCons` env) e3
+ res <- unify v2 v3
+ pure (res, ECase res e1' e2' e3')
+
+ ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)
+
+ EJust _ e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VIMaybe (Just v1)
+ pure (v, EJust v e1')
+
+ EMaybe _ e1 e2 e3 -> do
+ let STMaybe t1 = typeOf e3
+ (v3, e3') <- idana env e3
+ case v3 of
+ VIMaybe Nothing -> do
+ (v1, e1') <- idana env e1
+ scrap <- genIds t1
+ (_, e2') <- idana (scrap `SCons` env) e2
+ pure (v1, EMaybe v1 e1' e2' e3')
+ VIMaybe (Just v3j) -> do
+ (v2, e2') <- idana (v3j `SCons` env) e2
+ (_, e1') <- idana env e1
+ pure (v2, EMaybe v2 e1' e2' e3')
+ VIMaybe' v3' -> do
+ (v2, e2') <- idana (v3' `SCons` env) e2
+ (v1, e1') <- idana env e1
+ res <- unify v1 v2
+ pure (res, EMaybe res e1' e2' e3')
+
+ ELNil _ t1 t2 -> do
+ let v = VILEither (VIMaybe Nothing)
+ pure (v, ELNil v t1 t2)
+
+ ELInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VILEither (VIMaybe (Just (VIEither (Left v1))))
+ pure (v, ELInl v t2 e1')
+
+ ELInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VILEither (VIMaybe (Just (VIEither (Right v2))))
+ pure (v, ELInr v t1 e2')
+
+ ELCase _ e1 e2 e3 e4 -> do
+ let STLEither t1 t2 = typeOf e1
+ (v1L, e1') <- idana env e1
+ let VILEither v1 = v1L
+ let go mv1'l mv1'r f = do
+ v1'l <- maybe (genIds t1) pure mv1'l
+ v1'r <- maybe (genIds t2) pure mv1'r
+ (v2, e2') <- idana env e2
+ (v3, e3') <- idana (v1'l `SCons` env) e3
+ (v4, e4') <- idana (v1'r `SCons` env) e4
+ res <- f v2 v3 v4
+ pure (res, ELCase res e1' e2' e3' e4')
+ case v1 of
+ VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2)
+ VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3)
+ VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4)
+ VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4)
+ VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3)
+ VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4)
+ VIMaybe' (VIEither' v1'l v1'r) ->
+ go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4)
+
+ EConstArr _ dim t arr -> do
+ x1 <- VIArr <$> genId <*> vecReplicateA dim genId
+ pure (x1, EConstArr x1 dim t arr)
+
+ EBuild _ dim e1 e2 -> do
+ (shids, e1') <- idana env e1
+ x1 <- genIds (tTup (sreplicate dim tIx))
+ (_, e2') <- idana (x1 `SCons` env) e2
+ res <- VIArr <$> genId <*> shidsToVec dim shids
+ pure (res, EBuild res dim e1' e2')
+
+ EMap _ e1 e2 -> do
+ let STArr _ t = typeOf e2
+ x1 <- genIds t
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (v2, e2') <- idana env e2
+ let VIArr _ sh = v2
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMap res e1' e2')
+
+ EFold1Inner _ cm e1 e2 e3 -> do
+ let t1 = typeOf e1
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ (_ :< sh) = v3
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EFold1Inner res cm e1' e2' e3')
+
+ ESum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, ESum1Inner res e1')
+
+ EUnit _ e1 -> do
+ (_, e1') <- idana env e1
+ res <- VIArr <$> genId <*> pure VNil
+ pure (res, EUnit res e1')
+
+ EReplicate1Inner _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ let VIScal v1' = v1
+ (v2, e2') <- idana env e2
+ let VIArr _ sh = v2
+ res <- VIArr <$> genId <*> pure (v1' :< sh)
+ pure (res, EReplicate1Inner res e1' e2')
+
+ EMaximum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMaximum1Inner res e1')
+
+ EMinimum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMinimum1Inner res e1')
+
+ EReshape _ dim e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> shidsToVec dim v1
+ pure (res, EReshape res dim e1' e2')
+
+ EZip _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ let VIArr _ sh = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EZip res e1' e2')
+
+ EFold1InnerD1 _ cm e1 e2 e3 -> do
+ let t1 = typeOf e2
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ sh'@(_ :< sh) = v3
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
+ pure (res, EFold1InnerD1 res cm e1' e2' e3')
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ xf1 <- genIds t2
+ xf2 <- genIds tB
+ (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef
+ (v2, e2') <- idana env ebog
+ (_, e3') <- idana env ed
+ let VIArr _ sh@(_ :< sh') = v2
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm e1' e2' e3')
+
+ EConst _ t val -> do
+ res <- VIScal <$> genId
+ pure (res, EConst res t val)
+
+ EIdx0 _ e1 -> do
+ (_, e1') <- idana env e1
+ res <- genIds (typeOf expr)
+ pure (res, EIdx0 res e1')
+
+ EIdx1 _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ sh = v1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> pure (vecInit sh)
+ pure (res, EIdx1 res e1' e2')
+
+ EIdx _ e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (typeOf expr)
+ pure (res, EIdx res e1' e2')
+
+ EShape _ e1 -> do
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ let VIArr _ sh = v1
+ res = vecToShids dim sh
+ pure (res, EShape res e1')
+
+ EOp _ (op :: SOp a t) e1 -> do
+ (_, e1') <- idana env e1
+ res <- genIds (typeOf expr)
+ pure (res, EOp res op e1')
+
+ ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do
+ let t4 = typeOf e1
+ x1 <- genIds t2
+ x2 <- genIds t1
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1
+ x3 <- genIds (d1 t2)
+ x4 <- genIds (d1 t1)
+ (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2
+ x5 <- genIds (d2 t4)
+ x6 <- genIds t3
+ (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3
+ (_, e4') <- idana env e4
+ (_, e5') <- idana env e5
+ res <- genIds t4
+ pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
+ EWith _ t e1 e2 -> do
+ let t1 = typeOf e1
+ (_, e1') <- idana env e1
+ x1 <- VIAccum <$> genId
+ (v2, e2') <- idana (x1 `SCons` env) e2
+ x2 <- genIds t1
+ let res = VIPair v2 x2
+ pure (res, EWith res t e1' e2')
+
+ EAccum _ t prj e1 sp e2 e3 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ (_, e3') <- idana env e3
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
+
+ EZero _ t e1 -> do
+ -- Approximate the result of EZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EZero res t e1')
+
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
+
+ EPlus _ t e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (fromSMTy t)
+ pure (res, EPlus res t e1' e2')
+
+ EOneHot _ t i e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (fromSMTy t)
+ pure (res, EOneHot res t i e1' e2')
+
+ EError _ t s -> do
+ res <- genIds t
+ pure (res, EError res t s)
+
+-- | This value might be either of the two arguments; we don't know which.
+unify :: ValId t -> ValId t -> IdGen (ValId t)
+unify VINil VINil = pure VINil
+unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d
+unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b
+unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b
+unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b
+unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a
+unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c
+unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c
+unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b
+unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c
+unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d
+unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing
+unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b
+unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VILEither a) (VILEither b) = VILEither <$> unify a b
+unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
+unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
+unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
+
+unifyID :: Int -> Int -> IdGen Int
+unifyID i j | i == j = pure i
+ | otherwise = genId
+
+genIds :: STy t -> IdGen (ValId t)
+genIds STNil = pure VINil
+genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
+genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
+genIds (STMaybe t) = VIMaybe' <$> genIds t
+genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
+genIds STScal{} = VIScal <$> genId
+genIds STAccum{} = VIAccum <$> genId
+
+shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
+shidsToVec SZ _ = pure VNil
+shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is
+
+vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx))
+vecToShids SZ VNil = VINil
+vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i)
diff --git a/src/CHAD/Array.hs b/src/CHAD/Array.hs
new file mode 100644
index 0000000..f80f961
--- /dev/null
+++ b/src/CHAD/Array.hs
@@ -0,0 +1,131 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TupleSections #-}
+module CHAD.Array where
+
+import Control.DeepSeq
+import Control.Monad.Trans.State.Strict
+import Data.Foldable (traverse_)
+import Data.Vector (Vector)
+import qualified Data.Vector as V
+import GHC.Generics (Generic)
+
+import CHAD.Data
+
+
+data Shape n where
+ ShNil :: Shape Z
+ ShCons :: Shape n -> Int -> Shape (S n)
+deriving instance Show (Shape n)
+deriving instance Eq (Shape n)
+
+instance NFData (Shape n) where
+ rnf ShNil = ()
+ rnf (sh `ShCons` n) = rnf n `seq` rnf sh
+
+data Index n where
+ IxNil :: Index Z
+ IxCons :: Index n -> Int -> Index (S n)
+deriving instance Show (Index n)
+deriving instance Eq (Index n)
+
+instance NFData (Index n) where
+ rnf IxNil = ()
+ rnf (sh `IxCons` n) = rnf n `seq` rnf sh
+
+shapeSize :: Shape n -> Int
+shapeSize ShNil = 1
+shapeSize (ShCons sh n) = shapeSize sh * n
+
+shapeRank :: Shape n -> SNat n
+shapeRank ShNil = SZ
+shapeRank (sh `ShCons` _) = SS (shapeRank sh)
+
+fromLinearIndex :: Shape n -> Int -> Index n
+fromLinearIndex ShNil 0 = IxNil
+fromLinearIndex ShNil _ = error "Index out of range"
+fromLinearIndex (sh `ShCons` n) i =
+ let (q, r) = i `quotRem` n
+ in fromLinearIndex sh q `IxCons` r
+
+toLinearIndex :: Shape n -> Index n -> Int
+toLinearIndex ShNil IxNil = 0
+toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i
+
+emptyShape :: SNat n -> Shape n
+emptyShape SZ = ShNil
+emptyShape (SS m) = emptyShape m `ShCons` 0
+
+enumShape :: Shape n -> [Index n]
+enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1]
+
+shapeToList :: Shape n -> [Int]
+shapeToList = go []
+ where
+ go :: [Int] -> Shape n -> [Int]
+ go suff ShNil = suff
+ go suff (sh `ShCons` n) = go (n:suff) sh
+
+
+-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
+data Array (n :: Nat) t = Array (Shape n) (Vector t)
+ deriving (Show, Functor, Foldable, Traversable, Generic)
+instance NFData t => NFData (Array n t)
+
+arrayShape :: Array n t -> Shape n
+arrayShape (Array sh _) = sh
+
+arraySize :: Array n t -> Int
+arraySize (Array sh _) = shapeSize sh
+
+emptyArray :: SNat n -> Array n t
+emptyArray n = Array (emptyShape n) V.empty
+
+arrayFromList :: Shape n -> [t] -> Array n t
+arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l)
+
+arrayToList :: Array n t -> [t]
+arrayToList (Array _ v) = V.toList v
+
+arrayReshape :: Shape n -> Array m t -> Array n t
+arrayReshape sh (Array sh' v)
+ | shapeSize sh == shapeSize sh' = Array sh v
+ | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")"
+
+arrayUnit :: t -> Array Z t
+arrayUnit x = Array ShNil (V.singleton x)
+
+arrayIndex :: Array n t -> Index n -> t
+arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
+
+arrayIndexLinear :: Array n t -> Int -> t
+arrayIndexLinear (Array _ v) i = v V.! i
+
+arrayIndex1 :: Array (S n) t -> Int -> Array n t
+arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v)
+
+arrayGenerate :: Shape n -> (Index n -> t) -> Array n t
+arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh)
+
+arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t
+arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f)
+
+arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
+arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
+
+arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
+arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
+
+arrayMap :: (a -> b) -> Array n a -> Array n b
+arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr)
+
+arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
+arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr)
+
+-- | The Int is the linear index of the value.
+traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
+traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0
diff --git a/src/CHAD/Compile.hs b/src/CHAD/Compile.hs
new file mode 100644
index 0000000..5b71651
--- /dev/null
+++ b/src/CHAD/Compile.hs
@@ -0,0 +1,1796 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+module CHAD.Compile (compile, compileStderr) where
+
+import Control.Applicative (empty)
+import Control.Monad (forM_, when, replicateM)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Maybe
+import Control.Monad.Trans.State.Strict
+import Control.Monad.Trans.Writer.CPS
+import Data.Bifunctor (first)
+import Data.Char (ord)
+import Data.Foldable (toList)
+import Data.Functor.Const
+import qualified Data.Functor.Product as Product
+import Data.Functor.Product (Product)
+import Data.IORef
+import Data.List (foldl1', intersperse, intercalate)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (fromMaybe)
+import qualified Data.Set as Set
+import Data.Set (Set)
+import Data.Some
+import qualified Data.Vector as V
+import Foreign
+import GHC.Exts (int2Word#, addr2Int#)
+import GHC.Num (integerFromWord#)
+import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
+import Numeric (showHex)
+import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
+
+import Prelude hiding ((^))
+import qualified Prelude
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Pretty (ppSTy, ppExpr)
+import CHAD.AST.Sparse.Types (isDense)
+import CHAD.Compile.Exec
+import CHAD.Data
+import CHAD.Interpreter.Rep
+import qualified CHAD.Util.IdGen as IdGen
+
+
+-- In shape and index arrays, the innermost dimension is on the right (last index).
+
+-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
+
+
+-- | Print the compiled AST
+debugPrintAST :: Bool; debugPrintAST = toEnum 0
+-- | Print the generated C source
+debugCSource :: Bool; debugCSource = toEnum 0
+-- | Print extra stuff about reference counts of arrays
+debugRefc :: Bool; debugRefc = toEnum 0
+-- | Print some shape-related information
+debugShapes :: Bool; debugShapes = toEnum 0
+-- | Print information on allocation
+debugAllocs :: Bool; debugAllocs = toEnum 0
+-- | Emit extra C code that checks stuff
+emitChecks :: Bool; emitChecks = toEnum 0
+
+-- | Returns compiled function plus compilation output (warnings)
+compile :: SList STy env -> Ex env t
+ -> IO (SList Value env -> IO (Rep t), String)
+compile = \env expr -> do
+ codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
+
+ let (source, offsets) = compileToString codeID env expr
+ when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
+ when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
+ (lib, compileOutput) <- buildKernel source "kernel"
+
+ let result_type = typeOf expr
+ result_size = sizeofSTy result_type
+
+ let function val = do
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
+ serialiseArguments args ptr $ do
+ callKernelFun lib ptr
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
+ return (function, compileOutput)
+ where
+ serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
+ serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
+ serialise t arg ptr off $
+ serialiseArguments args ptr k
+ serialiseArguments _ _ k = k
+
+-- | 'compile', but writes any produced C compiler output to stderr.
+compileStderr :: SList STy env -> Ex env t
+ -> IO (SList Value env -> IO (Rep t))
+compileStderr env expr = do
+ (fun, output) <- compile env expr
+ when (not (null output)) $
+ hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>"
+ return fun
+
+
+data StructDecl = StructDecl
+ String -- ^ name
+ String -- ^ contents
+ String -- ^ comment
+ deriving (Show)
+
+data Stmt
+ = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
+ | SVarDeclUninit String String -- ^ type, variable name (no initialiser)
+ | SAsg String CExpr -- ^ variable name, right-hand side
+ | SBlock (Bag Stmt)
+ | SIf CExpr (Bag Stmt) (Bag Stmt)
+ | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>}
+ | SVerbatim String -- ^ no implicit ';', just printed as-is
+ deriving (Show)
+
+data CExpr
+ = CELit String -- ^ inserted as-is, assumed no parentheses needed
+ | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
+ | CEProj CExpr String -- ^ field projection: expr.field
+ | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field
+ | CEAddrOf CExpr -- ^ &expr
+ | CEIndex CExpr CExpr -- ^ expr[expr]
+ | CECall String [CExpr] -- ^ function(arg1, ..., argn)
+ | CEBinop CExpr String CExpr -- ^ expr + expr
+ | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
+ | CECast String CExpr -- ^ (<type>)<expr>
+ deriving (Show)
+
+printStructDecl :: StructDecl -> ShowS
+printStructDecl (StructDecl name contents comment) =
+ showString "typedef struct { " . showString contents . showString " } " . showString name
+ . showString ";" . (if null comment then id else showString (" // " ++ comment))
+
+printStmt :: Int -> Stmt -> ShowS
+printStmt indent = \case
+ SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
+ SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
+ SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
+ SBlock stmts
+ | null stmts -> showString "{}"
+ | otherwise ->
+ showString "{"
+ . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
+ . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
+ SIf cond b1 b2 ->
+ showString "if (" . printCExpr 0 cond . showString ") "
+ . printStmt indent (SBlock b1)
+ . (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
+ SLoop typ name e1 e2 stmts ->
+ showString ("for (" ++ typ ++ " " ++ name ++ " = ")
+ . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2
+ . showString ("; " ++ name ++ "++) ")
+ . printStmt indent (SBlock stmts)
+ SVerbatim s -> showString s
+
+-- d values:
+-- * 0: top level
+-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
+-- * 2-...: various operators (see precTable)
+-- * 80: address-of operator (&)
+-- * 98: inside unknown operator
+-- * 99: left of a field projection
+-- Unlisted operators are conservatively written with full parentheses.
+printCExpr :: Int -> CExpr -> ShowS
+printCExpr d = \case
+ CELit s -> showString s
+ CEStruct name pairs ->
+ showParen (d >= 99) $
+ showString ("(" ++ name ++ "){")
+ . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e
+ | (n, e) <- pairs])
+ . showString "}"
+ CEProj e name -> printCExpr 99 e . showString ("." ++ name)
+ CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)
+ CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
+ CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
+ CECall n es ->
+ showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
+ CEBinop e1 n e2 ->
+ let mprec = Map.lookup n precTable
+ p = maybe (-1) fst mprec -- precedence of this operator
+ (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments
+ in showParen (d > p) $
+ printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2
+ CEIf e1 e2 e3 ->
+ showParen (d > 0) $
+ printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
+ CECast typ e ->
+ showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e
+ where
+ precTable = Map.fromList
+ [("||", (2, (2, 2)))
+ ,("&&", (3, (3, 3)))
+ ,("==", (4, (5, 5)))
+ ,("!=", (4, (5, 5)))
+ ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop
+ ,(">", (5, (6, 6)))
+ ,("<=", (5, (6, 6)))
+ ,(">=", (5, (6, 6)))
+ ,("+", (6, (6, 7)))
+ ,("-", (6, (6, 7)))
+ ,("*", (7, (7, 8)))
+ ,("/", (7, (7, 8)))
+ ,("%", (7, (7, 8)))]
+
+repSTy :: STy t -> String
+repSTy (STScal st) = case st of
+ STI32 -> "int32_t"
+ STI64 -> "int64_t"
+ STF32 -> "float"
+ STF64 -> "double"
+ STBool -> "uint8_t"
+repSTy t = genStructName t
+
+genStructName, genArrBufStructName :: STy t -> String
+(genStructName, genArrBufStructName) =
+ (\t -> "ty_" ++ gen t
+ ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension
+ t -> error $ "genArrBufStructName: not an array type: " ++ show t)
+ where
+ -- all tags start with a letter, so the array mangling is unambiguous.
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
+
+-- The subtrees contain structs used in the bodies of the structs in this node.
+data StructTree = TreeNode [StructDecl] [StructTree]
+ deriving (Show)
+
+-- | This function generates the actual struct declarations for each of the
+-- types in our language. It thus implicitly "documents" the layout of the
+-- types in the C translation.
+--
+-- For accumulation it is important that for struct representations of monoid
+-- types, the all-zero-bytes value corresponds to the zero value of that type.
+buildStructTree :: STy t -> StructTree
+buildStructTree topty = case topty of
+ STNil ->
+ TreeNode [StructDecl name "" com] []
+ STPair a b ->
+ TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ [buildStructTree a, buildStructTree b]
+ STEither a b -> -- 0 -> l, 1 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STMaybe t -> -- 0 -> nothing, 1 -> just
+ TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ [buildStructTree t]
+ STArr n t ->
+ -- The buffer is trailed by a VLA for the actual array data.
+ -- TODO: no buffer if n = 0
+ TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") ""
+ ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
+ [buildStructTree t]
+ STScal _ ->
+ TreeNode [] []
+ STAccum t ->
+ TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
+ [buildStructTree (fromSMTy t)]
+ where
+ name = genStructName topty
+ com = ppSTy 0 topty
+
+-- State: already-generated (skippable) struct names
+-- Writer: the structs in declaration order
+genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructTreeW (TreeNode these deps) = do
+ seen <- lift get
+ case filter ((`Set.notMember` seen) . nameOf) these of
+ [] -> pure ()
+ structs -> do
+ lift $ modify (Set.fromList (map nameOf structs) <>)
+ mapM_ genStructTreeW deps
+ tell (BList structs)
+ where
+ nameOf (StructDecl name _ _) = name
+
+genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
+genAllStructs tys =
+ let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys
+ in toList (evalState (execWriterT m) mempty)
+
+data CompState = CompState
+ { csStructs :: Set (Some STy)
+ , csTopLevelDecls :: Bag String
+ , csStmts :: Bag Stmt
+ , csNextId :: Int }
+ deriving (Show)
+
+newtype CompM a = CompM (State CompState a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runCompM :: CompM a -> (a, CompState)
+runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
+
+class Monad m => MonadNameGen m where genId :: m Int
+instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
+instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
+
+genName' :: MonadNameGen m => String -> m String
+genName' "" = genName
+genName' prefix = (prefix ++) . show <$> genId
+
+genName :: MonadNameGen m => m String
+genName = genName' "x"
+
+onlyIdGen :: IdGen.IdGen a -> CompM a
+onlyIdGen m = CompM $ do
+ i1 <- gets csNextId
+ let (res, i2) = IdGen.runIdGen' i1 m
+ modify (\s -> s { csNextId = i2 })
+ return res
+
+emit :: Stmt -> CompM ()
+emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
+
+scope :: CompM a -> CompM (a, Bag Stmt)
+scope m = do
+ stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
+ res <- m
+ innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
+ return (res, innerStmts)
+
+emitStruct :: STy t -> CompM String
+emitStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+ return (genStructName ty)
+
+-- | Also returns the name of the array buffer struct
+emitArrStruct :: STy t -> CompM (String, String)
+emitArrStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+ return (genStructName ty, genArrBufStructName ty)
+
+emitTLD :: String -> CompM ()
+emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+
+nameEnv :: SList f env -> SList (Const String) env
+nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
+
+data KernelOffsets = KernelOffsets
+ { koArgOffsets :: [Int] -- ^ the function arguments
+ , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
+ , koResultOffset :: Int -- ^ the function result
+ }
+
+compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
+compileToString codeID env expr =
+ let args = nameEnv env
+ (res, s) = runCompM (compile' args expr)
+ structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
+
+ (arg_pairs, arg_metrics) =
+ unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
+ (slistZip env args))
+ (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
+ result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
+
+ offsets = KernelOffsets
+ { koArgOffsets = arg_offsets
+ , koOkResOffset = okres_offset
+ , koResultOffset = result_offset }
+ in (,offsets) . ($ "") $ compose
+ [showString "#include <stdio.h>\n"
+ ,showString "#include <stdint.h>\n"
+ ,showString "#include <stdbool.h>\n"
+ ,showString "#include <inttypes.h>\n"
+ ,showString "#include <stdlib.h>\n"
+ ,showString "#include <string.h>\n"
+ ,showString "#include <math.h>\n\n"
+ -- PRint-tag
+ ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
+
+ ,compose [printStructDecl sd . showString "\n" | sd <- structs]
+ ,showString "\n"
+
+ -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
+ ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
+ ,showString " void *ptr = malloc(n);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
+ else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
+ ,showString " return ptr;\n"
+ ,showString "}\n"
+ ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
+ ,showString " void *ptr = calloc(n, 1);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
+ else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
+ ,showString " return ptr;\n"
+ ,showString "}\n"
+ ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void free_instr(void *ptr) {\n"
+ ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
+ else id
+ ,showString " free(ptr);\n"
+ ,showString "}\n\n"
+
+ ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
+
+ ,showString $
+ "static bool typed_kernel(" ++
+ repSTy (typeOf expr) ++ " *output" ++
+ concatMap (", " ++)
+ (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
+ ") {\n"
+ ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
+ ,showString " *output = " . printCExpr 0 res . showString ";\n"
+ ,showString " return true;\n"
+ ,showString "}\n\n"
+
+ ,showString "void kernel(void *data) {\n"
+ -- Some code here assumes that we're on a 64-bit system, so let's check that
+ ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
+ else id
+ ,showString $ " const bool success = typed_kernel(" ++
+ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
+ concat (map (\((arg, typ), off) ->
+ ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+ ++ " /* " ++ arg ++ " */")
+ (zip arg_pairs arg_offsets)) ++
+ "\n );\n"
+ ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
+ else id
+ ,showString "}\n"]
+
+-- | Takes list of metrics (alignment, sizeof).
+-- Returns (offsets, size of struct).
+computeStructOffsets :: [(Int, Int)] -> ([Int], Int)
+computeStructOffsets = go 0 0
+ where
+ go off maxal [(al, sz)] =
+ ([off], align (max maxal al) (off + sz))
+ go off maxal ((al, sz) : pairs@((al2,_):_)) =
+ first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs
+ go _ _ [] = ([], 0)
+
+-- | Assumes that this is called at the correct alignment.
+serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r
+serialise topty topval ptr off k =
+ -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+ case (topty, topval) of
+ (STNil, ()) -> k
+ (STPair a b, (x, y)) ->
+ serialise a x ptr off $
+ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
+ (STEither a _, Left x) -> do
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STEither _ b, Right y) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
+ (STMaybe _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STMaybe t, Just x) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise t x ptr (off + alignmentSTy t) k
+ (STArr n t, Array sh vec) -> do
+ let eltsz = sizeofSTy t
+ allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
+ pokeByteOff ptr off bufptr
+ pokeShape ptr (off + 8) n sh
+
+ pokeByteOff @Word64 bufptr 0 (2 ^ 63)
+
+ let loop i
+ | i == shapeSize sh = k
+ | otherwise =
+ serialise t (vec V.! i) bufptr (8 + i * eltsz) $
+ loop (i+1)
+ loop 0
+ (STScal sty, x) -> case sty of
+ STI32 -> pokeByteOff ptr off (x :: Int32) >> k
+ STI64 -> pokeByteOff ptr off (x :: Int64) >> k
+ STF32 -> pokeByteOff ptr off (x :: Float) >> k
+ STF64 -> pokeByteOff ptr off (x :: Double) >> k
+ STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
+ (STAccum{}, _) -> error "Cannot serialise accumulators"
+
+-- | Assumes that this is called at the correct alignment.
+deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
+deserialise topty ptr off =
+ -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+ case topty of
+ STNil -> return ()
+ STPair a b -> do
+ x <- deserialise a ptr off
+ y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a))
+ return (x, y)
+ STEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
+ then Left <$> deserialise a ptr (off + alignmentSTy topty)
+ else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
+ STMaybe t -> do
+ tag <- peekByteOff @Word8 ptr off
+ if tag == 0
+ then return Nothing
+ else Just <$> deserialise t ptr (off + alignmentSTy t)
+ STArr n t -> do
+ bufptr <- peekByteOff @(Ptr ()) ptr off
+ sh <- peekShape ptr (off + 8) n
+ refc <- peekByteOff @Word64 bufptr 0
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
+ let eltsz = sizeofSTy t
+ arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
+ when (refc < 2 ^ 62) $ free bufptr
+ return arr
+ STScal sty -> case sty of
+ STI32 -> peekByteOff @Int32 ptr off
+ STI64 -> peekByteOff @Int64 ptr off
+ STF32 -> peekByteOff @Float ptr off
+ STF64 -> peekByteOff @Double ptr off
+ STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
+ STAccum{} -> error "Cannot serialise accumulators"
+
+align :: Int -> Int -> Int
+align a off = (off + a - 1) `div` a * a
+
+alignmentSTy :: STy t -> Int
+alignmentSTy = fst . metricsSTy
+
+sizeofSTy :: STy t -> Int
+sizeofSTy = snd . metricsSTy
+
+-- | Returns (alignment, sizeof)
+metricsSTy :: STy t -> (Int, Int)
+metricsSTy STNil = (1, 0)
+metricsSTy (STPair a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, align (max a1 a2) (s1 + s2))
+metricsSTy (STEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STMaybe t) =
+ let (a, s) = metricsSTy t
+ in (a, a + s) -- the union after the tag byte is aligned
+metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)
+metricsSTy (STScal sty) = case sty of
+ STI32 -> (4, 4)
+ STI64 -> (8, 8)
+ STF32 -> (4, 4)
+ STF64 -> (8, 8)
+ STBool -> (1, 1) -- compiled to uint8_t
+metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
+
+pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
+pokeShape ptr off = go . fromSNat
+ where
+ go :: Int -> Shape n -> IO ()
+ go rank = \case
+ ShNil -> return ()
+ sh `ShCons` n -> do
+ pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64)
+ go (rank - 1) sh
+
+peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n)
+peekShape ptr off = \case
+ SZ -> return ShNil
+ SS n -> ShCons <$> peekShape ptr off n
+ <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8))
+
+compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
+compile' env = \case
+ EVar _ t i -> do
+ let Const var = slistIdx env i
+ incrementVarAlways "var" Increment t var
+ return $ CELit var
+
+ ELet _ rhs body -> do
+ var <- compileAssign "" env rhs
+ rete <- compile' (Const var `SCons` env) body
+ incrementVarAlways "let" Decrement (typeOf rhs) var
+ return rete
+
+ EPair _ a b -> do
+ name <- emitStruct (STPair (typeOf a) (typeOf b))
+ e1 <- compile' env a
+ e2 <- compile' env b
+ return $ CEStruct name [("a", e1), ("b", e2)]
+
+ EFst _ e -> do
+ let STPair _ t2 = typeOf e
+ e' <- compile' env e
+ case incrementVar "fst" Decrement t2 of
+ Nothing -> return $ CEProj e' "a"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".b")
+ return $ CEProj (CELit var) "a"
+
+ ESnd _ e -> do
+ let STPair t1 _ = typeOf e
+ e' <- compile' env e
+ case incrementVar "snd" Decrement t1 of
+ Nothing -> return $ CEProj e' "b"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".a")
+ return $ CEProj (CELit var) "b"
+
+ ENil _ -> do
+ name <- emitStruct STNil
+ return $ CEStruct name []
+
+ EInl _ t e -> do
+ name <- emitStruct (STEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "0"), ("l", e1)]
+
+ EInr _ t e -> do
+ name <- emitStruct (STEither t (typeOf e))
+ e2 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("r", e2)]
+
+ ECase _ (EOp _ OIf e) a b -> do
+ e1 <- compile' env e
+ (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
+ (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SIf e1
+ (stmts2 <> pure (SAsg retvar e2))
+ (stmts3 <> pure (SAsg retvar e3))
+ return (CELit retvar)
+
+ ECase _ e a b -> do
+ let STEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ -- I know those are not variable names, but it's fine, probably
+ (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b
+ ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2
+ <> stmtsRel1
+ <> pure (SAsg retvar e2))
+ (stmts3
+ <> stmtsRel2
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ ENothing _ t -> do
+ name <- emitStruct (STMaybe t)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ EJust _ e -> do
+ name <- emitStruct (STMaybe (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("j", e1)]
+
+ EMaybe _ a b e -> do
+ let STMaybe t = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b
+ ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2
+ <> pure (SAsg retvar e2))
+ (stmts3
+ <> stmtsRel
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ ELNil _ t1 t2 -> do
+ name <- emitStruct (STLEither t1 t2)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ ELInl _ t e -> do
+ name <- emitStruct (STLEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
+
+ ELInr _ t e -> do
+ name <- emitStruct (STLEither t (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
+
+ ELCase _ e a b c -> do
+ let STLEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
+ (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
+ ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2 <> pure (SAsg retvar e2))
+ (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
+ (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
+ (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
+ return (CELit retvar)
+
+ EConstArr _ n t (Array sh vec) -> do
+ (strname, bufstrname) <- emitArrStruct (STArr n (STScal t))
+ tldname <- genName' "carraybuf"
+ -- Give it a refcount of _half_ the size_t max, so that it can be
+ -- incremented and decremented at will and will "never" reach anything
+ -- where something happens
+ emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++
+ "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++
+ ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname
+ [("buf", CEAddrOf (CELit tldname))
+ ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
+
+ EBuild _ n esh efun -> do
+ shname <- compileAssign "sh" env esh
+
+ arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
+
+ idxargname <- genName' "ix"
+ (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
+
+ linivar <- genName' "li"
+ ivars <- replicateM (fromSNat n) (genName' "i")
+ emit $ SBlock $
+ pure (SVarDecl False "size_t" linivar (CELit "0"))
+ <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
+ | (ivar, dimidx) <- zip ivars [0::Int ..]]
+ (pure (SVarDecl True (repSTy (typeOf esh)) idxargname
+ (shapeTupFromLitVars n ivars))
+ <> funstmts
+ <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval))
+
+ return (CELit arrname)
+
+ -- TODO: actually generate decent code here
+ EMap _ e1 e2 -> do
+ let STArr n _ = typeOf e2
+ compile' env $
+ elet e2 $
+ EBuild ext n (EShape ext (evar IZ)) $
+ elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e1
+
+ EFold1Inner _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+
+ -- let vecwid = case commut of Commut -> 8 :: Int
+ -- Noncommut -> 1
+
+ x0name <- compileAssign "foldx0" env ex0
+ arrname <- compileAssign "foldarr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1i" arrname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, which is
+ -- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
+
+ resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ -- kvar <- if vecwid > 1 then genName' "k" else return ""
+
+ accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
+
+ let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
+ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
+
+ pairstrname <- emitStruct (STPair t t)
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
+ <> funStmts
+ <> pure (SAsg accvar funres))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldx0" Decrement t x0name
+ incrementVarAlways "foldarr" Decrement (typeOf earr) arrname
+
+ return (CELit resname)
+
+ ESum1Inner _ e -> do
+ let STArr (SS n) t = typeOf e
+ argname <- compileAssign "sumarg" env e
+
+ zeroRefcountCheck (typeOf e) "sum1i" argname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ let vecwid = 8 :: Int
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ kvar <- genName' "k"
+ accvar <- genName' "tot"
+ let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
+ -- we have ScalIsNumeric, so it has 0 and (+) in C
+ [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
+ ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
+ pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
+ ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
+
+ incrementVarAlways "sum" Decrement (typeOf e) argname
+
+ return (CELit resname)
+
+ EUnit _ e -> do
+ e' <- compile' env e
+ let typ = STArr SZ (typeOf e)
+ strname <- emitStruct typ
+ name <- genName
+ emit $ SVarDecl True strname name (CEStruct strname
+ [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])])
+ emit $ SAsg (name ++ ".buf->refc") (CELit "1")
+ emit $ SAsg (name ++ ".buf->xs[0]") e'
+ return (CELit name)
+
+ EReplicate1Inner _ elen earg -> do
+ let STArr n t = typeOf earg
+ lenname <- compileAssign "replen" env elen
+ argname <- compileAssign "reparg" env earg
+
+ zeroRefcountCheck (typeOf earg) "replicate1i" argname
+
+ shszname <- genName' "shsz"
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray "repl1i" Malloc "rep" (SS n) t
+ (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
+ (compileArrShapeComponents n argname ++ [CELit lenname])
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]")
+ (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]"))
+
+ incrementVarAlways "repl1i" Decrement (typeOf earg) argname
+
+ return (CELit resname)
+
+ EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
+
+ EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
+ -- TODO: actually generate decent code here
+ EZip _ e1 e2 -> do
+ let STArr n _ = typeOf e1
+ compile' env $
+ elet e1 $
+ elet (weakenExpr WSink e2) $
+ EBuild ext n (EShape ext (evar (IS IZ))) $
+ EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ))
+
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ pairstrname <- emitStruct (STPair t t)
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
+ EConst _ t x -> return $ CELit $ compileScal True t x
+
+ EIdx0 _ e -> do
+ let STArr _ t = typeOf e
+ arrname <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "idx0" arrname
+ name <- genName
+ emit $ SVarDecl True (repSTy t) name
+ (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
+ incrementVarAlways "idx0" Decrement (STArr SZ t) arrname
+ return (CELit name)
+
+ -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
+
+ EIdx _ earr eidx -> do
+ let STArr n t = typeOf earr
+ arrname <- compileAssign "ixarr" env earr
+ zeroRefcountCheck (typeOf earr) "idx" arrname
+ idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
+ then compileAssign "ixix" env eidx
+ else return "" -- won't be used in this case
+
+ when emitChecks $
+ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
+ emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
+ arrname ++ ".buf); return false;")
+ mempty
+
+ resname <- genName' "ixres"
+ emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
+ incrementVarAlways "idxelt" Increment t resname
+ incrementVarAlways "idx" Decrement (STArr n t) arrname
+ return (CELit resname)
+
+ EShape _ e -> do
+ let STArr n _ = typeOf e
+ t = tTup (sreplicate n tIx)
+ _ <- emitStruct t
+ name <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "shape" name
+ resname <- genName
+ emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
+ incrementVarAlways "shape" Decrement (typeOf e) name
+ return (CELit resname)
+
+ EOp _ op (EPair _ e1 e2) -> do
+ e1' <- compile' env e1
+ e2' <- compile' env e2
+ compileOpPair op e1' e2'
+
+ EOp _ op e -> do
+ e' <- compile' env e
+ compileOpGeneral op e'
+
+ ECustom _ _ _ _ earg _ _ e1 e2 -> do
+ name1 <- compileAssign "" env e1
+ name2 <- compileAssign "" env e2
+ case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of
+ (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ (mfun1, mfun2) -> do
+ name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ maybe (return ()) ($ name1) mfun1
+ maybe (return ()) ($ name2) mfun2
+ return (CELit name)
+
+ ERecompute _ e -> compile' env e
+
+ EWith _ t e1 e2 -> do
+ actyname <- emitStruct (STAccum t)
+ name1 <- compileAssign "" env e1
+
+ zeroRefcountCheck (typeOf e1) "with" name1
+
+ emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
+ mcopy <- copyForWriting t name1
+ accname <- genName' "accum"
+ emit $ SVarDecl False actyname accname
+ (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
+ emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
+ emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
+
+ e2' <- compile' (Const accname `SCons` env) e2
+
+ resname <- genName' "acret"
+ emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
+
+ rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
+ return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
+
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
+ let -- Add a value (s) into an existing accumulation value (d). If a sparse
+ -- component of d is encountered, s is copied there.
+ add :: SMTy a -> String -> String -> CompM ()
+ add SMTNil _ _ = return ()
+ add (SMTPair t1 t2) d s = do
+ add t1 (d++".a") (s++".a")
+ add t2 (d++".b") (s++".b")
+ add (SMTLEither t1 t2) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
+ ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
+ ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ ((if emitChecks
+ then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
+ "&&"
+ (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
+ "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
+ "return false;")
+ mempty)
+ else mempty)
+ -- note: s may have tag 0
+ <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ stmts1
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
+ stmts2 mempty))))
+ add (SMTMaybe t1) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
+ ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
+ add (SMTArr n t1) d s = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ [0 .. fromSNat n - 1] $ \j -> do
+ emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
+ "!="
+ (CELit (d ++ ".sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
+ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
+ d ++ ".buf" ++
+ concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ ", " ++ s ++ ".buf" ++
+ concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ shsizename <- genName' "acshsz"
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
+ ivar <- genName' "i"
+ ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1
+ add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
+ emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
+ emit $ SVerbatim $ "// compile EAccum end"
+
+ incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
+
+ return $ CEStruct (repSTy STNil) []
+
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
+ EError _ t s -> do
+ let padleft len c s' = replicate (len - length s) c ++ s'
+ escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
+ | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
+ | otherwise -> [c]
+ emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
+ case t of
+ STScal _ -> return (CELit "0")
+ _ -> do
+ name <- emitStruct t
+ return $ CEStruct name []
+
+ EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+
+ EIdx1{} -> error "Compile: not implemented: EIdx1"
+
+compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
+compileAssign prefix env e = do
+ e' <- compile' env e
+ case e' of
+ CELit name -> return name
+ _ -> do
+ name <- genName' prefix
+ emit $ SVarDecl True (repSTy (typeOf e)) name e'
+ return name
+
+data Increment = Increment | Decrement
+ deriving (Show)
+
+-- | Increment reference counts in the components of the given variable.
+incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ())
+incrementVar marker inc ty =
+ let tree = makeArrayTree ty
+ in case tree of ATNoop -> Nothing
+ _ -> Just $ \var -> incrementVar' marker inc var tree
+
+incrementVarAlways :: String -> Increment -> STy a -> String -> CompM ()
+incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty)
+
+data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array)
+ | ATNoop -- ^ don't do anything here
+ | ATProj String ArrayTree -- ^ descend one field deeper
+ | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
+ | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
+ | ATBoth ArrayTree ArrayTree -- ^ do both these paths
+
+smartATProj :: String -> ArrayTree -> ArrayTree
+smartATProj _ ATNoop = ATNoop
+smartATProj field t = ATProj field t
+
+smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
+smartATCondTag ATNoop ATNoop = ATNoop
+smartATCondTag t t' = ATCondTag t t'
+
+smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
+smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
+smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
+
+smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
+smartATBoth ATNoop t = t
+smartATBoth t ATNoop = t
+smartATBoth t t' = ATBoth t t'
+
+makeArrayTree :: STy a -> ArrayTree
+makeArrayTree STNil = ATNoop
+makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
+ (smartATProj "b" (makeArrayTree b))
+makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
+makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
+makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
+makeArrayTree (STScal _) = ATNoop
+makeArrayTree (STAccum _) = ATNoop
+
+incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
+incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
+ case inc of
+ Increment -> do
+ emit $ SVerbatim (path ++ ".buf->refc++;")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
+ Decrement -> do
+ case incrementVar (marker++".elt") Decrement eltty of
+ Nothing ->
+ if debugRefc
+ then do
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
+ else do
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
+ Just f -> do
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ shszvar <- genName' "frshsz"
+ ivar <- genName' "i"
+ ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
+ emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0"))
+ (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path)
+ ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $
+ eltDecrStmts
+ ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"])
+ mempty
+incrementVar' _ _ _ ATNoop = pure ()
+incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t
+incrementVar' marker inc path (ATCondTag t1 t2) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
+incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
+ stmts2
+ (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
+ stmts3
+ stmts1))
+incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
+
+toLinearIdx :: SNat n -> String -> String -> CExpr
+toLinearIdx SZ _ _ = CELit "0"
+toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
+toLinearIdx (SS n) arrvar idxvar =
+ CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
+ "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))))
+ "+" (CELit (idxvar ++ ".b"))
+
+-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
+-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
+-- fromLinearIdx (SS n) arrvar idxvar = do
+-- name <- genName
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
+-- _
+
+data AllocMethod = Malloc | Calloc
+ deriving (Show)
+
+-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray marker method nameBase rank eltty mshsz shape = do
+ when (length shape /= fromSNat rank) $
+ error "allocArray: shape does not match rank"
+ let arrty = STArr rank eltty
+ strname <- emitStruct arrty
+ arrname <- genName' nameBase
+ shsz <- case mshsz of
+ Just e -> return e
+ Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
+ let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
+ "+"
+ (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
+ emit $ SVarDecl True strname arrname $ CEStruct strname
+ [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
+ Calloc -> CECall "calloc_instr" [nbytesExpr])
+ ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))]
+ emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
+ return arrname
+
+compileShapeQuery :: SNat n -> String -> CExpr
+compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
+compileShapeQuery (SS n) var =
+ CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
+ [("a", compileShapeQuery n var)
+ ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))]
+
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeSize :: SNat n -> String -> CExpr
+compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
+
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeComponents :: SNat n -> String -> [CExpr]
+compileArrShapeComponents n var =
+ [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+
+indexTupleComponents :: SNat n -> String -> [CExpr]
+indexTupleComponents = \n var -> map CELit (toList (go n var))
+ where
+ go :: SNat n -> String -> Bag String
+ go SZ _ = mempty
+ go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b")
+
+-- | Takes variable names with the innermost dimension on the right.
+shapeTupFromLitVars :: SNat n -> [String] -> CExpr
+shapeTupFromLitVars = \n -> go n . reverse
+ where
+ -- takes variables with the innermost dimension at the _head_
+ go :: SNat n -> [String] -> CExpr
+ go SZ [] = CEStruct (repSTy STNil) []
+ go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
+ go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+
+prodExpr :: [CExpr] -> CExpr
+prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
+
+compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
+compileOpGeneral op e1 = do
+ let unary cop = return @CompM $ CECall cop [e1]
+ let binary cop = do
+ name <- genName
+ emit $ SVarDecl True (repSTy (opt1 op)) name e1
+ return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
+ case op of
+ OAdd _ -> binary "+"
+ OMul _ -> binary "*"
+ ONeg _ -> unary "-"
+ OLt _ -> binary "<"
+ OLe _ -> binary "<="
+ OEq _ -> binary "=="
+ ONot -> unary "!"
+ OAnd -> binary "&&"
+ OOr -> binary "||"
+ OIf -> do
+ name <- emitStruct (STEither STNil STNil)
+ _ <- emitStruct STNil
+ return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
+ (CEStruct name [("tag", CELit "1")])
+ ORound64 -> unary "(int64_t)round" -- ew
+ OToFl64 -> unary "(double)"
+ ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
+ OExp STF32 -> unary "expf"
+ OExp STF64 -> unary "exp"
+ OLog STF32 -> unary "logf"
+ OLog STF64 -> unary "log"
+ OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
+
+compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
+compileOpPair op e1 e2 = do
+ let binary cop = return @CompM $ CEBinop e1 cop e2
+ case op of
+ OAdd _ -> binary "+"
+ OMul _ -> binary "*"
+ OLt _ -> binary "<"
+ OLe _ -> binary "<="
+ OEq _ -> binary "=="
+ OAnd -> binary "&&"
+ OOr -> binary "||"
+ OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
+ _ -> error "compileOpPair: got unary operator"
+
+-- | Bool: whether to ensure that the literal itself already has the appropriate type
+compileScal :: Bool -> SScalTy t -> ScalRep t -> String
+compileScal pedantic typ x = case typ of
+ STI32 -> (if pedantic then "(int32_t)" else "") ++ show x
+ STI64 -> (if pedantic then "(int64_t)" else "") ++ show x
+ STF32 -> show x ++ "f"
+ STF64 -> show x
+ STBool -> if x then "1" else "0"
+
+compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
+compileExtremum nameBase opName operator env e = do
+ let STArr (SS n) t = typeOf e
+ argname <- compileAssign (nameBase ++ "arg") env e
+
+ zeroRefcountCheck (typeOf e) opName argname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, which is
+ -- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ xvar <- genName
+ redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
+ -- we have ScalIsNumeric, so it has 1 and (<) etc. in C
+ [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]"))
+ ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList
+ [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]"))
+ ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar)
+ ]
+ ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)]
+
+ incrementVarAlways nameBase Decrement (typeOf e) argname
+
+ return (CELit resname)
+
+-- | If this returns Nothing, there was nothing to copy because making a simple
+-- value copy in C already makes it suitable to write to.
+copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
+copyForWriting topty var = case topty of
+ SMTNil -> return Nothing
+
+ SMTPair a b -> do
+ e1 <- copyForWriting a (var ++ ".a")
+ e2 <- copyForWriting b (var ++ ".b")
+ case (e1, e2) of
+ (Nothing, Nothing) -> return Nothing
+ _ -> return $ Just $ CEStruct toptyname
+ [("a", fromMaybe (CELit (var++".a")) e1)
+ ,("b", fromMaybe (CELit (var++".b")) e2)]
+
+ SMTLEither a b -> do
+ (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
+ (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
+ case (e1, e2) of
+ (Nothing, Nothing) -> return Nothing
+ _ -> do
+ name <- genName
+ emit $ SVarDeclUninit toptyname name
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
+ (stmts1
+ <> pure (SAsg name (CEStruct toptyname
+ [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
+ (stmts2
+ <> pure (SAsg name (CEStruct toptyname
+ [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
+ return (Just (CELit name))
+
+ SMTMaybe t -> do
+ (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
+ case e1 of
+ Nothing -> return Nothing
+ Just e1' -> do
+ name <- genName
+ emit $ SVarDeclUninit toptyname name
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
+ (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
+ (stmts1
+ <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
+ return (Just (CELit name))
+
+ -- If there are no nested arrays, we know that a refcount of 1 means that the
+ -- whole thing is owned. Nested arrays have their own refcount, so with
+ -- nesting we'd have to check the refcounts of all the nested arrays _too_;
+ -- let's not do that. Furthermore, no sub-arrays means that the whole thing
+ -- is flat, and we can just memcpy if necessary.
+ SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
+ name <- genName
+ shszname <- genName' "shsz"
+ emit $ SVarDeclUninit toptyname name
+
+ when debugShapes $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ emit $ SVerbatim $
+ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
+ concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ ");"
+
+ emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
+ (pure (SAsg name (CELit var)))
+ (let shbytes = fromSNat n * 8
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
+ totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
+ in BList
+ [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
+ ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
+ ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
+ ,SAsg (name ++ ".buf->refc") (CELit "1")
+ ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
+ printCExpr 0 databytes ");"])
+ return (Just (CELit name))
+
+ SMTArr n t -> do
+ shszname <- genName' "shsz"
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
+
+ let shbytes = fromSNat n * 8
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
+ totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
+
+ name <- genName
+ emit $ SVarDecl False toptyname name
+ (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
+ emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
+ emit $ SAsg (name ++ ".buf->refc") (CELit "1")
+
+ -- put the arrays in variables to cut short the not-quite-var chain
+ dstvar <- genName' "cpydst"
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
+ srcvar <- genName' "cpysrc"
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
+
+ ivar <- genName' "i"
+
+ (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]")
+ let cpye' = case cpye of
+ Just e -> e
+ Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ cpystmts
+ <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')
+
+ return (Just (CELit name))
+
+ SMTScal _ -> return Nothing
+
+ where
+ toptyname = repSTy (fromSMTy topty)
+
+zeroRefcountCheck :: STy t -> String -> String -> CompM ()
+zeroRefcountCheck toptyp opname topvar =
+ when emitChecks $ do
+ mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
+ case mstmts of
+ Nothing -> return ()
+ Just stmts -> forM_ stmts emit
+ where
+ -- | If this returns 'Nothing', no statements need to be generated for this type.
+ go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
+ go STNil _ = empty
+ go (STPair a b) path = do
+ (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
+ return (s1 <> s2)
+ go (STEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
+ go (STMaybe a) path = do
+ ss <- go a (path++".j")
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
+ go (STArr n a) path = do
+ ivar <- genName' "i"
+ ss <- go a (path++".buf->xs["++ivar++"]")
+ shszname <- genName' "shsz"
+ let s1 = SVerbatim $
+ "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
+ "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
+ let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
+ let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
+ return (BList [s1, s2, s3])
+ go STScal{} _ = empty
+ go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
+
+ combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
+ combine (MaybeT a) (MaybeT b) = MaybeT $ do
+ x <- a
+ y <- b
+ return $ case (x, y) of
+ (Nothing, Nothing) -> Nothing
+ (Just x', Nothing) -> Just (x', mempty)
+ (Nothing, Just y') -> Just (mempty, y')
+ (Just x', Just y') -> Just (x', y')
+
+{-# NOINLINE uniqueIdGenRef #-}
+uniqueIdGenRef :: IORef Int
+uniqueIdGenRef = unsafePerformIO $ newIORef 1
+
+compose :: Foldable t => t (a -> a) -> a -> a
+compose = foldr (.) id
+
+showPtr :: Ptr a -> String
+showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) ""
+
+-- | Type-restricted.
+(^) :: Num a => a -> Int -> a
+(^) = (Prelude.^)
+
+foldl0' :: (a -> a -> a) -> a -> [a] -> a
+foldl0' _ x [] = x
+foldl0' f _ l = foldl1' f l
diff --git a/src/CHAD/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs
new file mode 100644
index 0000000..5b4afc8
--- /dev/null
+++ b/src/CHAD/Compile/Exec.hs
@@ -0,0 +1,99 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
+module CHAD.Compile.Exec (
+ KernelLib,
+ buildKernel,
+ callKernelFun,
+
+ -- * misc
+ lineNumbers,
+) where
+
+import Control.Monad (when)
+import Data.IORef
+import Foreign (Ptr)
+import Foreign.Ptr (FunPtr)
+import System.Directory (removeDirectoryRecursive)
+import System.Environment (lookupEnv)
+import System.Exit (ExitCode(..))
+import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
+import System.Posix.DynamicLinker
+import System.Posix.Temp (mkdtemp)
+import System.Process (readProcessWithExitCode)
+
+
+debug :: Bool
+debug = False
+
+-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
+data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ())))
+
+buildKernel :: String -> String -> IO (KernelLib, String)
+buildKernel csource funname = do
+ template <- (++ "/tmp.chad.") <$> getTempDir
+ path <- mkdtemp template
+
+ let outso = path ++ "/out.so"
+ let args = ["-O3", "-march=native"
+ ,"-shared", "-fPIC"
+ ,"-std=c99", "-x", "c"
+ ,"-o", outso, "-"
+ ,"-Wall", "-Wextra"
+ ,"-Wno-unused-variable", "-Wno-unused-but-set-variable"
+ ,"-Wno-unused-parameter", "-Wno-unused-function"
+ ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives
+ ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain
+ (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource
+
+ -- Print the source before the GCC output.
+ case ec of
+ ExitSuccess -> return ()
+ ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>"
+
+ case ec of
+ ExitSuccess -> return ()
+ ExitFailure{} -> do
+ removeDirectoryRecursive path
+ ioError (mkIOError userErrorType "chad kernel compilation failed" Nothing Nothing)
+
+ numLoaded <- atomicModifyIORef' numLoadedCounter (\n -> (n+1, n+1))
+ when debug $ hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)"
+ dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL]
+
+ removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now
+
+ ref <- newIORef =<< dlsym dl funname
+ _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))
+ when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
+ dlclose dl)
+ return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr)
+
+foreign import ccall "dynamic"
+ wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO ()
+
+-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive
+{-# NOINLINE callKernelFun #-}
+callKernelFun :: KernelLib -> Ptr () -> IO ()
+callKernelFun (KernelLib ref) arg = do
+ ptr <- readIORef ref
+ wrapKernelFun ptr arg
+
+getTempDir :: IO FilePath
+getTempDir =
+ lookupEnv "TMPDIR" >>= \case
+ Just s | not (null s) -> return s
+ _ -> return "/tmp"
+
+{-# NOINLINE numLoadedCounter #-}
+numLoadedCounter :: IORef Int
+numLoadedCounter = unsafePerformIO $ newIORef 0
+
+lineNumbers :: String -> String
+lineNumbers str =
+ let lns = lines str
+ numlines = length lns
+ width = length (show numlines)
+ pad s = replicate (width - length s) ' ' ++ s
+ in unlines (zipWith (\i ln -> pad (show i) ++ " | " ++ ln) [1::Int ..] lns)
diff --git a/src/CHAD/Data.hs b/src/CHAD/Data.hs
new file mode 100644
index 0000000..8c7605c
--- /dev/null
+++ b/src/CHAD/Data.hs
@@ -0,0 +1,192 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where
+
+import Data.Functor.Product
+import Data.GADT.Compare
+import Data.GADT.Show
+import Data.Some
+import Data.Type.Bool (If)
+import Data.Type.Equality
+import Unsafe.Coerce (unsafeCoerce)
+
+import CHAD.Lemmas (Append)
+
+
+data Dict c where
+ Dict :: c => Dict c
+
+
+data SList f l where
+ SNil :: SList f '[]
+ SCons :: f a -> SList f l -> SList f (a : l)
+deriving instance (forall a. Show (f a)) => Show (SList f l)
+infixr `SCons`
+
+slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list
+slistMap _ SNil = SNil
+slistMap f (SCons x list) = SCons (f x) (slistMap f list)
+
+slistMapA :: Applicative m => (forall t. f t -> m (g t)) -> SList f list -> m (SList g list)
+slistMapA _ SNil = pure SNil
+slistMapA f (SCons x list) = SCons <$> f x <*> slistMapA f list
+
+slistZip :: SList f list -> SList g list -> SList (Product f g) list
+slistZip SNil SNil = SNil
+slistZip (x `SCons` l1) (y `SCons` l2) = Pair x y `SCons` slistZip l1 l2
+
+unSList :: (forall t. f t -> a) -> SList f list -> [a]
+unSList _ SNil = []
+unSList f (x `SCons` l) = f x : unSList f l
+
+showSList :: (forall t. Int -> f t -> String) -> SList f list -> String
+showSList _ SNil = "SNil"
+showSList f (x `SCons` l) = f 11 x ++ " `SCons` " ++ showSList f l
+
+sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
+sappend SNil l = l
+sappend (SCons x xs) l = SCons x (sappend xs l)
+
+type family Replicate n x where
+ Replicate Z x = '[]
+ Replicate (S n) x = x : Replicate n x
+
+sreplicate :: SNat n -> f t -> SList f (Replicate n t)
+sreplicate SZ _ = SNil
+sreplicate (SS n) x = x `SCons` sreplicate n x
+
+data Nat = Z | S Nat
+ deriving (Show, Eq, Ord)
+
+type N0 = Z
+type N1 = S N0
+type N2 = S N1
+type N3 = S N2
+
+data SNat n where
+ SZ :: SNat Z
+ SS :: SNat n -> SNat (S n)
+deriving instance Show (SNat n)
+
+instance GCompare SNat where
+ gcompare SZ SZ = GEQ
+ gcompare SZ _ = GLT
+ gcompare _ SZ = GGT
+ gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n')
+
+instance TestEquality SNat where testEquality = geq
+instance GEq SNat where geq = defaultGeq
+instance GShow SNat where gshowsPrec = defaultGshowsPrec
+
+fromSNat :: SNat n -> Int
+fromSNat SZ = 0
+fromSNat (SS n) = succ (fromSNat n)
+
+unSNat :: SNat n -> Nat
+unSNat SZ = Z
+unSNat (SS n) = S (unSNat n)
+
+reSNat :: Nat -> Some SNat
+reSNat Z = Some SZ
+reSNat (S n) | Some n' <- reSNat n = Some (SS n')
+
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
+snatKnown :: SNat n -> Dict (KnownNat n)
+snatKnown SZ = Dict
+snatKnown (SS n) | Dict <- snatKnown n = Dict
+
+type family n + m where
+ Z + m = m
+ S n + m = S (n + m)
+
+type family n - m where
+ n - Z = n
+ S n - S m = n - m
+
+snatAdd :: SNat n -> SNat m -> SNat (n + m)
+snatAdd SZ m = m
+snatAdd (SS n) m = SS (snatAdd n m)
+
+lemPlusSuccRight :: n + S m :~: S (n + m)
+lemPlusSuccRight = unsafeCoerceRefl
+
+lemPlusZero :: n + Z :~: n
+lemPlusZero = unsafeCoerceRefl
+
+data Vec n t where
+ VNil :: Vec Z t
+ (:<) :: t -> Vec n t -> Vec (S n) t
+deriving instance Show t => Show (Vec n t)
+deriving instance Eq t => Eq (Vec n t)
+deriving instance Functor (Vec n)
+deriving instance Foldable (Vec n)
+deriving instance Traversable (Vec n)
+
+vecLength :: Vec n t -> SNat n
+vecLength VNil = SZ
+vecLength (_ :< v) = SS (vecLength v)
+
+vecGenerate :: SNat n -> (forall i. SNat i -> t) -> Vec n t
+vecGenerate = \n f -> go n f SZ
+ where
+ go :: SNat n -> (forall i. SNat i -> t) -> SNat i' -> Vec n t
+ go SZ _ _ = VNil
+ go (SS n) f i = f i :< go n f (SS i)
+
+vecReplicateA :: Applicative f => SNat n -> f a -> f (Vec n a)
+vecReplicateA SZ _ = pure VNil
+vecReplicateA (SS n) gen = (:<) <$> gen <*> vecReplicateA n gen
+
+vecZipWithA :: Applicative f => (a -> b -> f c) -> Vec n a -> Vec n b -> f (Vec n c)
+vecZipWithA _ VNil VNil = pure VNil
+vecZipWithA f (x :< xs) (y :< ys) = (:<) <$> f x y <*> vecZipWithA f xs ys
+
+vecInit :: Vec (S n) a -> Vec n a
+vecInit (_ :< VNil) = VNil
+vecInit (x :< xs@(_ :< _)) = x :< vecInit xs
+
+unsafeCoerceRefl :: a :~: b
+unsafeCoerceRefl = unsafeCoerce Refl
+
+gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a')
+gorderingLift1 GLT = GLT
+gorderingLift1 GGT = GGT
+gorderingLift1 GEQ = GEQ
+
+gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b')
+gorderingLift2 GLT _ = GLT
+gorderingLift2 GGT _ = GGT
+gorderingLift2 GEQ GLT = GLT
+gorderingLift2 GEQ GGT = GGT
+gorderingLift2 GEQ GEQ = GEQ
+
+data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t]
+ deriving (Show, Functor, Foldable, Traversable)
+
+-- | This instance is mostly there just for 'pure'
+instance Applicative Bag where
+ pure = BOne
+ BNone <*> _ = BNone
+ BOne f <*> b = f <$> b
+ BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b)
+ BMany bs <*> b = BMany (map (<*> b) bs)
+ BList bs <*> b = BMany (map (<$> b) bs)
+
+instance Semigroup (Bag t) where (<>) = BTwo
+instance Monoid (Bag t) where mempty = BNone
+
+data SBool b where
+ SF :: SBool False
+ ST :: SBool True
+deriving instance Show (SBool b)
diff --git a/src/CHAD/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs
new file mode 100644
index 0000000..6e16b82
--- /dev/null
+++ b/src/CHAD/Data/VarMap.hs
@@ -0,0 +1,119 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Data.VarMap (
+ VarMap,
+ empty,
+ insert,
+ delete,
+ TypedIdx(..),
+ lookup,
+ disjointUnion,
+ sink1,
+ unsink1,
+ subMap,
+ superMap,
+) where
+
+import Prelude hiding (lookup)
+
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (mapMaybe)
+import Data.Some
+import qualified Data.Vector.Storable as VS
+import Unsafe.Coerce
+
+import CHAD.AST.Env
+import CHAD.AST.Types
+import CHAD.AST.Weaken
+
+
+type role VarMap _ nominal -- ensure that 'env' is not phantom
+data VarMap k (env :: [Ty]) =
+ VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index
+ Int -- ^ Time since last cleanup
+ (Map k (Some STy, Int))
+deriving instance Show k => Show (VarMap k env)
+
+empty :: VarMap k env
+empty = VarMap 0 0 Map.empty
+
+insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env
+insert k ty idx (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp)
+
+delete :: Ord k => k -> VarMap k env -> VarMap k env
+delete k (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp)
+
+data TypedIdx env t = TypedIdx (STy t) (Idx env t)
+ deriving (Show)
+
+lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env))
+lookup k (VarMap off _ mp) = do
+ (Some ty, i) <- Map.lookup k mp
+ idx <- unsafeInt2idx (i + off)
+ return (Some (TypedIdx ty idx))
+
+disjointUnion :: Ord k => VarMap k env -> VarMap k env -> VarMap k env
+disjointUnion (VarMap off1 cl1 m1) (VarMap off2 cl2 m2) | off1 == off2 =
+ VarMap off1 (min cl1 cl2) (Map.unionWith (error "VarMap.disjointUnion: overlapping keys") m1 m2)
+disjointUnion vm1 vm2 = disjointUnion (cleanup vm1) (cleanup vm2)
+
+sink1 :: VarMap k env -> VarMap k (t : env)
+sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp
+
+unsink1 :: VarMap k (t : env) -> VarMap k env
+unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp
+
+subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'
+subMap subenv =
+ let bools = let loop :: Subenv env env' -> [Bool]
+ loop SETop = []
+ loop (SEYesR sub) = True : loop sub
+ loop (SENo sub) = False : loop sub
+ in VS.fromList $ loop subenv
+ newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools
+ modify off (k, (ty, i))
+ | i + off < 0 = Nothing
+ | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map"
+ | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off)))
+ | otherwise = Nothing
+ in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp)
+
+superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env
+superMap subenv =
+ let loop :: Subenv env env' -> Int -> [Int]
+ loop SETop _ = []
+ loop (SEYesR sub) i = i : loop sub (i+1)
+ loop (SENo sub) i = loop sub (i+1)
+
+ newIndices = VS.fromList $ loop subenv 0
+ modify off (k, (ty, i))
+ | i + off < 0 = Nothing
+ | i + off >= VS.length newIndices = error "VarMap.superMap: found negative indices in map"
+ | otherwise = let j = newIndices VS.! (i + off)
+ in if j == -1 then Nothing else Just (k, (ty, j))
+
+ in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp)
+
+maybeCleanup :: VarMap k env -> VarMap k env
+maybeCleanup vm@(VarMap _ interval mp)
+ | let sz = Map.size mp
+ , sz > 0, 2 * interval >= 3 * sz
+ = cleanup vm
+maybeCleanup vm = vm
+
+cleanup :: VarMap k env -> VarMap k env
+cleanup (VarMap off _ mp) = VarMap 0 0 (Map.mapMaybe (\(t, i) -> if i + off >= 0 then Just (t, i + off) else Nothing) mp)
+
+unsafeInt2idx :: Int -> Maybe (Idx env t)
+unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n)
+ where
+ go :: Int -> Idx env t
+ go 0 = unsafeCoerce IZ
+ go n = unsafeCoerce (IS (go (n-1)))
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
new file mode 100644
index 0000000..595d3c7
--- /dev/null
+++ b/src/CHAD/Drev.hs
@@ -0,0 +1,1583 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module CHAD.Drev (
+ drev,
+ freezeRet,
+ CHADConfig(..),
+ defaultConfig,
+ Storage(..),
+ Descr(..),
+ Select,
+) where
+
+import Data.Functor.Const
+import Data.Some
+import Data.Type.Equality (type (==), testEquality)
+
+import CHAD.Analysis.Identity (ValId(..), validSplitEither)
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.AST.Count
+import CHAD.AST.Env
+import CHAD.AST.Sparse
+import CHAD.AST.Weaken.Auto
+import CHAD.Data
+import qualified CHAD.Data.VarMap as VarMap
+import CHAD.Data.VarMap (VarMap)
+import CHAD.Drev.Accum
+import CHAD.Drev.EnvDescr
+import CHAD.Drev.Types
+import CHAD.Lemmas
+
+
+------------------------------ TAPES AND BINDINGS ------------------------------
+
+type family Tape binds where
+ Tape '[] = TNil
+ Tape (t : ts) = TPair t (Tape ts)
+
+tapeTy :: SList STy binds -> STy (Tape binds)
+tapeTy SNil = STNil
+tapeTy (SCons t ts) = STPair t (tapeTy ts)
+
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
+ EPair ext (EVar ext t (w @> IZ))
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
+
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
+
+-- In order from large to small: i.e. in reverse order from what we want,
+-- because in a Bindings, the head of the list is the bottom-most entry.
+type family TapeUnfoldings binds where
+ TapeUnfoldings '[] = '[]
+ TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts
+
+type family Reverse l where
+ Reverse '[] = '[]
+ Reverse (t : ts) = Append (Reverse ts) '[t]
+
+-- An expression that is always 'snd'
+data UnfExpr env t where
+ UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t
+
+fromUnfExpr :: UnfExpr env t -> Ex env t
+fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ)
+
+-- - A bunch of 'snd' expressions taking us from knowing that there's a
+-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix
+-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in
+-- the environment.
+-- - In the extended environment, another bunch of let bindings (these are
+-- 'fst' expressions, but no need to know that statically) that project the
+-- fsts out of what we introduced above, one for each type in 'ts'.
+data Reconstructor env ts =
+ Reconstructor
+ (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts)))
+ (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts)
+
+ssnoc :: SList f ts -> f t -> SList f (Append ts '[t])
+ssnoc SNil a = SCons a SNil
+ssnoc (SCons t ts) a = SCons t (ssnoc ts a)
+
+sreverse :: SList f ts -> SList f (Reverse ts)
+sreverse SNil = SNil
+sreverse (SCons t ts) = ssnoc (sreverse ts) t
+
+stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
+stapeUnfoldings SNil = SNil
+stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts)
+
+-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one.
+shiftUnfolder
+ :: STy t
+ -> SList STy ts
+ -> Bindings UnfExpr (Tape ts : env) list
+ -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts])
+shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts))
+shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) =
+ -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order
+ -- to expand an 'Append' in the types so that things simplify just enough.
+ -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list
+ -- of bindings produced by 'b'. We want to conclude from this that
+ -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know
+ -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after
+ -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step.
+ BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t
+ BPush{} -> UnfExSnd itemTy t)
+
+growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts)
+growRecon t ts (Reconstructor unfbs bs)
+ | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts])
+ , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env)
+ , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env
+ = Reconstructor
+ (shiftUnfolder t ts unfbs)
+ -- Add a 'fst' at the bottom of the builder stack.
+ -- First we have to weaken most of 'bs' to skip one more binding in the
+ -- unfolder stack above it.
+ (BPush (fst (weakenBindingsE
+ (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
+ (WSink :: env :> (Tape (t : ts) : env))) bs))
+ (t
+ ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $
+ wSinks @(Tape (t : ts) : env)
+ (sappend ts
+ (sappend (sappend (sreverse (stapeUnfoldings ts))
+ (SCons (tapeTy ts) SNil))
+ SNil))
+ @> IZ))
+
+buildReconstructor :: SList STy ts -> Reconstructor env ts
+buildReconstructor SNil = Reconstructor BTop BTop
+buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)
+
+-- STRATEGY FOR reconstructBindings
+--
+-- binds = []
+-- e : ()
+--
+-- binds = [c]
+-- e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst e : c
+--
+-- binds = [b, c]
+-- e : (b, (c, ()))
+-- x1 = snd e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+--
+-- binds = [a, b, c]
+-- e : (a, (b, (c, ())))
+-- x2 = snd e : (b, (c, ()))
+-- x1 = snd x2 : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+-- y3 = fst x3 : a
+
+-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all
+-- the things in the list 'binds', we want to create a let stack that extracts
+-- all values from that tuple and in effect "restores" the environment
+-- described by 'binds'. The idea is that elsewhere, we took a slice of the
+-- environment and saved it all in a tuple to be restored later. We
+-- incidentally also add a bunch of additional bindings, namely 'Reverse
+-- (TapeUnfoldings binds)', so the calling code just has to skip those in
+-- whatever it wants to do.
+reconstructBindings :: SList STy binds
+ -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
+ ,SList STy (Reverse (TapeUnfoldings binds)))
+reconstructBindings binds =
+ (\tape -> let Reconstructor unf build = buildReconstructor binds
+ in fst $ weakenBindingsE (WIdx tape)
+ (bconcat (mapBindings fromUnfExpr unf) build)
+ ,sreverse (stapeUnfoldings binds))
+
+
+---------------------------------- DERIVATIVES ---------------------------------
+
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
+d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
+d1op (ORecip t) e = EOp ext (ORecip t) e
+d1op (OExp t) e = EOp ext (OExp t) e
+d1op (OLog t) e = EOp ext (OLog t) e
+d1op (OIDiv t) e = EOp ext (OIDiv t) e
+d1op (OMod t) e = EOp ext (OMod t) e
+
+-- | Both primal and dual must be duplicable expressions
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
+
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
+ ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ OToFl64 -> Linear $ \_ -> ENil ext
+ ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
+ OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
+ OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+
+ floatingD2 :: ScalIsFloating a ~ True
+ => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
+ floatingD2 STF32 k = k
+ floatingD2 STF64 k = k
+
+ integralD2 :: ScalIsIntegral a ~ True
+ => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r
+ integralD2 STI32 k = k
+ integralD2 STI64 k = k
+
+desD1E :: Descr env sto -> SList STy (D1E env)
+desD1E = d1e . descrList
+
+-- d1W :: env :> env' -> D1E env :> D1E env'
+-- d1W WId = WId
+-- d1W WSink = WSink
+-- d1W (WCopy w) = WCopy (d1W w)
+-- d1W (WPop w) = WPop (d1W w)
+-- d1W (WThen u w) = WThen (d1W u) (d1W w)
+
+conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
+conv1Idx IZ = IZ
+conv1Idx (IS i) = IS (conv1Idx i)
+
+data Idx2 env sto t
+ = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
+ | Idx2Di (Idx (Select env sto "discr") t)
+
+conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
+conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ
+conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ
+conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ
+conv2Idx (DPush des (_, _, SAccum)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, _, SMerge)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me (IS j)
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di (IS j)
+conv2Idx DTop i = case i of {}
+
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
+ -> r)
+ -> r
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
+
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
+
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
+
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+
+--------------------------------- ACCUMULATORS ---------------------------------
+
+fromArrayValId :: Maybe (ValId t) -> Maybe Int
+fromArrayValId (Just (VIArr i _)) = Just i
+fromArrayValId _ = Nothing
+
+accumPromote :: forall dt env sto proxy r.
+ proxy dt
+ -> Descr env sto
+ -> (forall stoRepl envPro.
+ (Select env stoRepl "merge" ~ '[])
+ => Descr env stoRepl
+ -- ^ A revised environment description that switches
+ -- arrays (used in the OccEnv) that are currently on
+ -- "merge" storage, to "accum" storage.
+ -> SList STy envPro
+ -- ^ New entries on top of the original dual environment,
+ -- that house the accumulators for the promoted arrays in
+ -- the original environment.
+ -> Subenv (Select env sto "merge") envPro
+ -- ^ The promoted entries were merge entries in the
+ -- original environment.
+ -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
+ -- ^ All entries that were accumulators are still
+ -- accumulators.
+ -> VarMap Int (D2AcE (Select env stoRepl "accum"))
+ -- ^ Accumulator map for _only_ the the newly allocated
+ -- accumulators.
+ -> (forall shbinds.
+ SList STy shbinds
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -- ^ A weakening that converts a computation in the
+ -- revised environment to one in the original environment
+ -- extended with some accumulators.
+ -> r)
+ -> r
+accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
+ -- Accumulators are left as-is
+ SAccum ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ envpro
+ prosub
+ (SEYesR accrevsub)
+ (VarMap.sink1 accumMap)
+ (\shbinds ->
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
+ (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ .> WCopy (wf shbinds)
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #tl)
+ (#acc :++: (#d :++: #shb :++: #tl)))
+
+ SMerge -> case t of
+ -- Discrete values are left as-is
+ _ | isDiscrete t ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ (SENo prosub)
+ accrevsub
+ accumMap'
+ wf
+
+ -- Values with "merge" storage are promoted to an accumulator in envPro
+ _ ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ (t `SCons` envpro)
+ (SEYesR prosub)
+ (SENo accrevsub)
+ (let accumMap' = VarMap.sink1 accumMap
+ in case fromArrayValId vid of
+ Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap'
+ Nothing -> accumMap')
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- Discrete values are left as-is, nothing to do
+ SDiscr ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ prosub
+ accrevsub
+ accumMap
+ wf
+ where
+ isDiscrete :: STy t' -> Bool
+ isDiscrete = \case
+ STNil -> True
+ STPair a b -> isDiscrete a && isDiscrete b
+ STEither a b -> isDiscrete a && isDiscrete b
+ STLEither a b -> isDiscrete a && isDiscrete b
+ STMaybe a -> isDiscrete a
+ STArr _ a -> isDiscrete a
+ STScal st -> case st of
+ STI32 -> True
+ STI64 -> True
+ STF32 -> False
+ STF64 -> False
+ STBool -> True
+ STAccum{} -> False
+
+
+---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
+
+data Ret env0 sto sd t =
+ forall shbinds tapebinds contribs.
+ Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (Ex (Append shbinds (D1E env0)) (D1 t))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
+
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
+
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
+
+data Rets env0 sto env list =
+ forall shbinds tapebinds.
+ Rets (Bindings Ex env shbinds)
+ (Subenv shbinds tapebinds)
+ (SList (RetPair env0 sto env shbinds tapebinds) list)
+deriving instance Show (Rets env0 sto env list)
+
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
+weakenRetPair :: SList STy shbinds -> env :> env'
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
+weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
+
+weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
+weakenRets w (Rets binds tapesub list) =
+ let (binds', _) = weakenBindingsE w binds
+ in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
+
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
+ Descr env0 sto
+ -> SList f b1 -> SList f b2
+ -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
+ | Refl <- lemAppendAssoc @b2 @b1 @env =
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
+
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
+retConcat _ SNil = Rets BTop SETop SNil
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
+ | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
+ , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
+ = Rets (bconcat e0 binds)
+ (subenvConcat subtape subtape2)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
+ sub
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
+ subtape subtape2)
+ pairs))
+
+freezeRet :: Descr env sto
+ -> Ret env sto (D2 t) t
+ -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
+ let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
+ in letBinds e0' $
+ EPair ext
+ (weakenExpr wInsertD2Ac e1)
+ (ELet ext (weakenExpr (autoWeak library
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
+ (#shbinds :++: #d :++: #d2ace :++: #tl))
+ e2') $
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
+
+
+---------------------------- THE CHAD TRANSFORMATION ---------------------------
+
+drev :: forall env sto sd t.
+ (?config :: CHADConfig)
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
+ EVar _ t i ->
+ case conv2Idx des i of
+ Idx2Ac accI ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvNone (d2e (select SMerge des)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+
+ Idx2Me tupI ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
+
+ Idx2Di _ ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ ELet _ (rhs :: Expr _ _ a) body
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
+ , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
+ Ret (bconcat (rhs0 `bpush` rhs1) body0')
+ (subenvConcat subtapeRHS subtapeBody)
+ (weakenExpr wbody0' body1)
+ subBoth
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #tl)
+ (#d :++: (#body :++: #rhs) :++: #tl))
+ body2) $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
+ plus_RHS_Body
+ (EVar ext (contribTupTy des subRHS) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
+
+ EPair _ a b
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EPair ext a1 b1)
+ subBoth
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
+
+ EFst _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
+ Ret e0
+ subtape
+ (EFst ext e1)
+ sub
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
+ weakenExpr (WCopy WSink) e2)
+
+ ESnd _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
+ Ret e0
+ subtape
+ (ESnd ext e1)
+ sub
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
+
+ EInl _ t2 e
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ (EInl ext (d1 t2) e1)
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
+
+ EInr _ t1 e
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ (EInr ext (d1 t1) e1)
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
+
+ ECase _ e (a :: Expr _ _ t) b
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
+ , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+ , let (bindids1, bindids2) = validSplitEither (extOf e)
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
+ , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
+ , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0
+ , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
+ ->
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
+ Ret (e0 `bpush` ECase ext e1
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))
+ (SEYesR subtapeE)
+ (EFst ext (EVar ext tPrimal IZ))
+ subOut
+ (elet
+ (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA
+ in letBinds (rebinds IZ) $
+ ELet ext
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
+ &. #prea0 prerebinds
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
+ &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #ta0 :++: #tl)
+ (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
+ a2) $
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB
+ in letBinds (rebinds IZ) $
+ ELet ext
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
+ &. #preb0 prerebinds
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
+ &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #tb0 :++: #tl)
+ (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
+ b2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
+ plus_AB_E
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
+
+ EConst _ t val ->
+ Ret BTop
+ SETop
+ (EConst ext t val)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EOp _ op e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
+ case d2op op of
+ Linear d2opfun ->
+ Ret e0
+ subtape
+ (d1op op e1)
+ sub
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
+ (weakenExpr (WCopy WSink) e2))
+ Nonlinear d2opfun ->
+ Ret (e0 `bpush` e1)
+ (SEYesR subtape)
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ sub
+ (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
+ (weakenExpr (WCopy (wSinks' @[_,_])) e2))
+
+ ECustom _ _ tb _ srce pr du a b
+ -- allowed to ignore a2 because 'a' is the part of the input that is inactive
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)
+ `bpush` ESnd ext (EVar ext (typeOf pr) IZ))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
+ (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @sd)
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
+
+ EError _ t s ->
+ Ret BTop
+ SETop
+ (EError ext (d1 t) s)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EConstArr _ n t val ->
+ Ret BTop
+ SETop
+ (EConstArr ext n t val)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
+ | SpArr @_ @sdElt sdElt <- sd
+ , let eltty = typeOf ef
+ , shty :: STy shty <- tTup (sreplicate ndim tIx)
+ , Refl <- indexTupD1Id ndim ->
+ drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
+ let library = #ix (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e provars)
+ &. #d1env (desD1E des)
+ &. #d (auto1 @sdElt)
+ &. #tape (auto1 @e_tape)
+ &. #pro (d2ace provars)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim e_tape)) in
+ Ret (proPrimalBinds
+ `bpush` weakenExpr (wSinks (d1e provars))
+ (EBuild ext ndim
+ (drevPrimal des she)
+ (letBinds e0 $
+ EPair ext e1 e1tape))
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ))
+ (SEYesR (SENo (subenvAll (d1e provars))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub)
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in
+ ESnd ext $
+ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
+ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv)
+ (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ e2)
+
+ EMap _ ef (earr :: Expr _ _ (TArr n a))
+ | SpArr sdElt <- sd
+ , let STArr ndim t1 = typeOf earr
+ t2 = typeOf ef ->
+ drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 ->
+ case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds
+ ttape = typeOf ef1tape
+ library = #d1env (desD1E des)
+ &. #a0 (bindingsBinds ea0)
+ &. #atapebinds (subList (bindingsBinds ea0) easubtape)
+ &. #propr (d1e provars)
+ &. #x (d1 t1 `SCons` SNil)
+ &. #parr (STArr ndim (d1 t1) `SCons` SNil)
+ &. #tapearr (STArr ndim ttape `SCons` SNil)
+ &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil)
+ &. #dy (applySparse sdElt (d2 t2) `SCons` SNil)
+ &. #tape (ttape `SCons` SNil)
+ &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #pro (d2ace provars)
+ in
+ subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a ->
+ Ret (bconcat ea0 proPrimalBinds'
+ `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1
+ `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env))
+ (letBinds ef0 $
+ EPair ext ef1 ef1tape))
+ (EVar ext (STArr ndim (d1 t1)) IZ)
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ))
+ (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars))))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ)))
+ subfa
+ (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout) $
+ emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $
+ elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv)
+ (#tape :++: #dy :++: #dytape :++: #pro :++: layout))
+ ef2)
+ (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ))
+ (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $
+ plus_f_a
+ (ESnd ext (evar IZ))
+ (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout))
+ (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ))
+ ea2)))
+ }
+
+ EFold1Inner _ commut origef ex₀ earr
+ | SpArr @_ @sdElt sdElt <- sd
+ , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
+ , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
+ let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
+ bogTy = STArr (SS ndim) bogEltTy
+ primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil)
+ &. #parr (auto1 @(TArr (S n) (D1 elt)))
+ &. #px₀ (auto1 @(D1 elt))
+ &. #px (auto1 @(D1 elt))
+ &. #pzi (auto1 @(ZeroInfo (D2 elt)))
+ &. #primal (primalTy `SCons` SNil)
+ &. #darr (auto1 @(TArr n sdElt))
+ &. #d (auto1 @(D2 elt))
+ &. #x₀abinds (bindingsBinds bindsx₀a)
+ &. #fbinds (bindingsBinds ef0)
+ &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
+ &. #ftape (auto1 @ef_tape)
+ &. #bogelt (bogEltTy `SCons` SNil)
+ &. #propr (d1e provars)
+ &. #d1env (desD1E des)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace provars)
+ &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
+ wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a proPrimalBinds'
+ `bpush` weakenExpr wOverPrimalBindings ex₀1
+ `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
+ `bpush` EFold1InnerD1 ext commut
+ (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
+ weakenExpr (autoWeak library (#xy :++: #d1env) layout)
+ (letBinds ef0 $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ ef1
+ (EPair ext
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ))
+ ef1tape)))
+ (EVar ext (d1 eltty) (IS (IS IZ)))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
+ (EFst ext (EVar ext primalTy IZ))
+ subx₀af
+ (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout1) $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
+ let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
+ expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
+ (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $
+ plus_x₀a_f
+ (plus_x₀_a
+ (elet (EIdx0 ext
+ (EFold1Inner ext Commut
+ (let t = STPair (d2 eltty) (d2 eltty)
+ in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
+ (eflatten (EFst ext (EFst ext (evar IZ)))))) $
+ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
+ ex₀2)
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (ESnd ext (evar IZ)))
+
+ EUnit _ e
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
+ Ret e0
+ subtape
+ (EUnit ext e1)
+ sub
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ EReplicate1Inner _ en e
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
+ , let STArr ndim eltty = typeOf e ->
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty))
+ in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
+
+ EIdx0 _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
+ , STArr _ t <- typeOf e ->
+ Ret e0
+ subtape
+ (EIdx0 ext e1)
+ sub
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
+ {-
+ EIdx1 _ e ei
+ -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
+ | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
+ , STArr (SS n) eltty <- typeOf e ->
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))
+ (SEYesR (SENo subtape))
+ (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
+ (weakenExpr (WSink .> WSink) ei1))
+ sub
+ (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (d2 eltty)) (IS IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ -}
+
+ EIdx _ e ei
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
+ , Refl <- indexTupD1Id n
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (typeOf e1) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
+ EShape _ e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
+ , Refl <- indexTupD1Id n ->
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
+ (ENil ext)
+
+ ESum1Inner _ e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , STArr (SS n) t <- typeOf e ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ))
+ (SEYesR (SENo subtape))
+ (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
+ sub
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
+
+ EReshape _ n esh e
+ | SpArr sd' <- sd
+ , STArr orign t <- typeOf e
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , Refl <- indexTupD1Id n ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ))
+ (SEYesR (SENo subtape))
+ (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh))
+ (EVar ext (STArr orign (d1 t)) (IS IZ)))
+ sub
+ (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
+ EZip _ a b
+ | SpArr sd' <- sd
+ , STArr n t1 <- typeOf a
+ , STArr _ t2 <- typeOf b ->
+ splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE ->
+ case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons`
+ toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of
+ { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EZip ext a1 b1)
+ subBoth
+ (case pairSplitE of
+ Left Refl ->
+ let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2)
+ (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2)
+ Right f -> f IZ $ \wrapPair pick1 pick2 ->
+ elet (emap (wrapPair (EPair ext pick1 pick2))
+ (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2)
+ (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2))
+ }
+
+ ENothing{} -> err_unsupported "ENothing"
+ EJust{} -> err_unsupported "EJust"
+ EMaybe{} -> err_unsupported "EMaybe"
+ ELNil{} -> err_unsupported "ELNil"
+ ELInl{} -> err_unsupported "ELInl"
+ ELInr{} -> err_unsupported "ELInr"
+ ELCase{} -> err_unsupported "ELCase"
+
+ EWith{} -> err_accum
+ EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
+ where
+ err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
+ err_unsupported s = error $ "CHAD: unsupported " ++ s
+ err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
+
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `bpush` e1
+ `bpush` extremum (EVar ext at IZ))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
+data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
+
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
+ RetScoped
+ (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
+ (Ex (Append shbinds (D1E (a : env0))) (D1 t))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ -- ^ merge contributions to the _enclosing_ merge environment
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
+ -- ^ the merge contributions, plus the cotangent to the argument
+ -- (if there is any)
+deriving instance Show (RetScoped env0 sto a s sd t)
+
+drevScoped :: forall a s env sto sd t.
+ (?config :: CHADConfig)
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
+ -> Expr ValId (a : env) t
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
+ SMerge
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ case sub of
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
+
+ SAccum
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
+ , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
+ , Just Refl <- testEquality foundTy (STAccum (d2M argty))
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
+ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
+ ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
+ weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ (EPair ext e2 (ENil ext))
+
+ | let accumMap' = case argids of
+ Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
+ _ -> VarMap.sink1 accumMap
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
+ e2
+
+ SDiscr
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
+ => Descr env sto
+ -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> (STy a, Storage s)
+ -> Sparse (D2 t) dt
+ -> Expr ValId (a : env) t
+ -> (forall provars shbinds tape d2a'.
+ SList STy provars
+ -> Subenv (D2E (Select env sto "merge")) (D2E provars)
+ -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator)
+ -> Bindings Ex (D1 a : D1E env) shbinds
+ -> Ex (Append shbinds (D1 a : D1E env)) (D1 t)
+ -> Ex (Append shbinds (D1 a : D1E env)) tape
+ -> Sparse (D2 a) d2a'
+ -> (forall env' b.
+ D1E provars :> env'
+ -> Ex (Append (D2AcE provars) env') b
+ -> Ex ( env') (TPair b (Tup (D2E provars))))
+ -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> r)
+ -> r
+drevLambda des accumMap (argty, argsto) sd origef k =
+ let t = typeOf origef in
+ deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ case prf1 prodes argty argsto of { Refl ->
+ case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ extractContrib prodes argty argsto subEf $ \argSp getSparseArg ->
+ let library = #fbinds (bindingsBinds ef0)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #arg (d1 argty `SCons` SNil)
+ &. #d (applySparse sd (d2 t) `SCons` SNil)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #efPrerebinds efPrerebinds in
+ k envPro
+ (subenvD2E (subenvCompose subMergeUsed proSub))
+ mergePrimalBindings
+ (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0))
+ (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: #arg :++: #d1env))
+ ef1)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env)))
+ argSp
+ (\wpro1 body ->
+ uninvertTup (d2e envPro) (typeOf body) $
+ makeAccumulators wpro1 envPro $
+ body)
+ (letBinds (efRebinds IZ) $
+ weakenExpr
+ (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ (getSparseArg ef2))
+ }}
+ where
+ extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False)
+ => proxy env sto -> proxy2 a -> Storage s
+ -- if s == "merge", this simplifies to SubenvS '[D2 a] t'
+ -- if s == "discr", this simplifies to SubenvS '[] t'
+ -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t'
+ -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r
+ extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id
+ extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext)
+ extractContrib _ _ SDiscr SETop k' = k' SpAbsent id
+
+ prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s
+ -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum"
+ prf1 _ _ SMerge = Refl
+ prf1 _ _ SDiscr = Refl
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
+ = mapExt (const ext) e
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Drev/Accum.hs
index a7bc53f..6f25f11 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Drev/Accum.hs
@@ -1,13 +1,13 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
-- | TODO this module is a grab-bag of random utility functions that are shared
--- between CHAD and CHAD.Top.
-module CHAD.Accum where
+-- between CHAD.Drev and CHAD.Drev.Top.
+module CHAD.Drev.Accum where
-import AST
-import CHAD.Types
-import Data
-import AST.Env
+import CHAD.AST
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.AST.Env
d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs
index 49ae0e6..5a90303 100644
--- a/src/CHAD/EnvDescr.hs
+++ b/src/CHAD/Drev/EnvDescr.hs
@@ -7,18 +7,18 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.EnvDescr where
+module CHAD.Drev.EnvDescr where
import Data.Kind (Type)
import Data.Some
import GHC.TypeLits (Symbol)
-import Analysis.Identity (ValId(..))
-import AST.Env
-import AST.Types
-import AST.Weaken
-import CHAD.Types
-import Data
+import CHAD.Analysis.Identity (ValId(..))
+import CHAD.AST.Env
+import CHAD.AST.Types
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Drev.Types
type Storage :: Symbol -> Type
diff --git a/src/CHAD/Top.hs b/src/CHAD/Drev/Top.hs
index 4814bdf..510e73e 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Drev/Top.hs
@@ -8,20 +8,20 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.Top where
+module CHAD.Drev.Top where
-import Analysis.Identity
-import AST
-import AST.Env
-import AST.Sparse
-import AST.SplitLets
-import AST.Weaken.Auto
-import CHAD
-import CHAD.Accum
-import CHAD.EnvDescr
-import CHAD.Types
-import Data
-import qualified Data.VarMap as VarMap
+import CHAD.Analysis.Identity
+import CHAD.AST
+import CHAD.AST.Env
+import CHAD.AST.Sparse
+import CHAD.AST.SplitLets
+import CHAD.AST.Weaken.Auto
+import CHAD.Data
+import qualified CHAD.Data.VarMap as VarMap
+import CHAD.Drev
+import CHAD.Drev.Accum
+import CHAD.Drev.EnvDescr
+import CHAD.Drev.Types
type family MergeEnv env where
diff --git a/src/CHAD/Types.hs b/src/CHAD/Drev/Types.hs
index 44ac20e..367a974 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Drev/Types.hs
@@ -2,11 +2,11 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.Types where
+module CHAD.Drev.Types where
-import AST.Accum
-import AST.Types
-import Data
+import CHAD.AST.Accum
+import CHAD.AST.Types
+import CHAD.Data
type family D1 t where
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs
index 888fed4..019119c 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Drev/Types/ToTan.hs
@@ -1,14 +1,14 @@
{-# LANGUAGE GADTs #-}
-module CHAD.Types.ToTan where
+module CHAD.Drev.Types.ToTan where
import Data.Bifunctor (bimap)
-import Array
-import AST.Types
-import CHAD.Types
-import Data
-import ForwardAD
-import Interpreter.Rep
+import CHAD.Array
+import CHAD.AST.Types
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.ForwardAD
+import CHAD.Interpreter.Rep
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
diff --git a/src/CHAD/Example.hs b/src/CHAD/Example.hs
new file mode 100644
index 0000000..884f99a
--- /dev/null
+++ b/src/CHAD/Example.hs
@@ -0,0 +1,197 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+
+{-# OPTIONS -Wno-unused-imports #-}
+module CHAD.Example where
+
+import Debug.Trace
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Pretty
+import CHAD.AST.UnMonoid
+import CHAD.Data
+import CHAD.Drev
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.Example.Types
+import CHAD.ForwardAD
+import CHAD.Interpreter
+import CHAD.Language
+import CHAD.Simplify
+
+
+-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
+
+
+pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+pipeline config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ simplifyFix $ chad' config knownEnv $
+ simplifyFix $ term
+
+-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures
+pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
+pipeline' config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ pprintExpr (pipeline config term)
+
+
+bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
+bin op a b = EOp ext op (EPair ext a b)
+
+senv1 :: SList STy [TScal TF32, TScal TF32]
+senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
+
+-- x y |- x * y + x
+--
+-- let x3 = (x1, x2)
+-- x4 = ((*) x3, x1)
+-- in ( (+) x4
+-- , let x5 = 1.0
+-- x6 = Inr (x5, x5)
+-- in case x6 of
+-- Inl x7 -> return ()
+-- Inr x8 ->
+-- let x9 = fst x8
+-- x10 = Inr (snd x3 * x9, fst x3 * x9)
+-- in case x10 of
+-- Inl x11 -> return ()
+-- Inr x12 ->
+-- let x13 = fst x12
+-- in one "v1" x13 >>= \x14 ->
+-- let x15 = snd x12
+-- in one "v2" x15 >>= \x16 ->
+-- let x17 = snd x8
+-- in one "v1" x17)
+--
+-- ( (x1 * x2) + x1
+-- , let x5 = 1.0
+-- in do one "v1" (x2 * x5)
+-- one "v2" (x1 * x5)
+-- one "v1" x5)
+ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex1 = fromNamed $ lambda #x $ lambda #y $ body $
+ #x * #y + #x
+
+-- x y |- let z = x + y in z * (z + x)
+ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex2 = fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #z (#x + #y) $
+ #z * (#z + #x)
+
+-- x y |- if x < y then 2 * x else 3 + x
+ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex3 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x) (3 * #x)
+
+-- x y |- if x < y then 2 * x + y * y else 3 + x
+ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex4 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
+
+-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
+ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
+ex5 = fromNamed $ lambda #x $ lambda #y $ body $
+ case_ #x (#a :-> #a * #y)
+ (#b :-> #b * (#y + 1))
+
+-- x:R n:I |- let a = unit x
+-- b = build1 n (\i. let c = idx0 a in c * c)
+-- in idx0 (b ! 3)
+ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
+ex6 = fromNamed $ lambda #x $ lambda #n $ body $
+ let_ #a (unit #x) $
+ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
+ #b ! pair nil 3
+
+-- A "neural network" except it's just scalars, not matrices.
+-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
+-- |- let p1 = snd ps
+-- p1' = fst ps
+-- x1 = fst p1 * x + snd p1
+-- p2 = snd p1'
+-- p2' = fst p1'
+-- x2 = fst p2 * x + snd p2
+-- p3 = snd p2'
+-- p3' = fst p2'
+-- x3 = fst p3 * x + snd p3
+-- in x3
+ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
+ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
+ let tR = STScal STF64
+ tpair = STPair tR tR
+
+ layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
+ => STy p -> NExpr env R
+ layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
+ let_ #par (snd_ #parstup) $
+ let_ #restpars (fst_ #parstup) $
+ let_ #inp (fst_ #par * #inp + snd_ #par) $
+ let_ #parstup #restpars $
+ layer t
+ layer STNil = #inp
+ layer _ = error "Invalid layer inputs"
+
+ in let_ #parstup #pars123 $
+ let_ #inp #input $
+ layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
+
+neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
+neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
+ let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
+ -- prod = wei `matmul` x
+ let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
+ #wei ! #idx * #x ! pair nil (snd_ #idx)) $
+ -- relu (prod + bias)
+ build (SS SZ) (shape #prod) $ #idx :->
+ let_ #out (#prod ! #idx + #bias ! #idx) $
+ if_ (#out .<= const_ 0) (const_ 0) #out
+
+ in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
+ let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
+ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
+ #x3 ! nil
+
+type NeuralGrad = ((Array N2 Double, Array N1 Double)
+ ,(Array N2 Double, Array N1 Double)
+ ,Array N1 Double
+ ,Array N1 Double)
+
+neuralGo :: (Double -- primal
+ ,NeuralGrad -- gradient using CHAD
+ ,NeuralGrad) -- gradient using dual-numbers forward AD
+neuralGo =
+ let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
+ input = arrayFromList (ShNil `ShCons` 2) [1,1]
+ argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
+ revderiv =
+ simplifyN 20 $
+ ELet ext (EConst ext STF64 1.0) $
+ chad defaultConfig knownEnv neural
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
+ in trace (ppExpr knownEnv revderiv) $
+ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
+
+-- The build body uses free variables in a non-linear way, so their primal
+-- values are required in the dual of the build. Thus, compositionally, they
+-- are stored in the tape from each individual lambda invocation. This results
+-- in n copies of y and z, where only one copy would have sufficed.
+exUniformFree :: Ex '[R, I64] R
+exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $
+ let_ #y (#x * 2) $
+ let_ #z (#x * 3) $
+ idx0 $ sum1i $
+ build1 #n $ #i :-> #y * #z + toFloat_ #i
diff --git a/src/CHAD/Example/GMM.hs b/src/CHAD/Example/GMM.hs
new file mode 100644
index 0000000..8f834e0
--- /dev/null
+++ b/src/CHAD/Example/GMM.hs
@@ -0,0 +1,124 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE TypeApplications #-}
+module CHAD.Example.GMM where
+
+import CHAD.Data (SList(..), SNat(..))
+import CHAD.Example.Types
+import CHAD.Language
+
+
+
+-- N, D, K: integers > 0
+-- alpha, M, Q, L: the active parameters
+-- X: inactive data
+-- m: integer
+-- k1: 1/2 N D log(2 pi)
+-- k2: 1/2 gamma^2
+-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D))
+-- where n' = D + m + 1
+--
+-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m.
+--
+-- See:
+-- - "A benchmark of selected algorithmic differentiation tools on some problems
+-- in computer vision and machine learning". Optim. Methods Softw. 33(4-6):
+-- 889-906 (2018).
+-- <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
+-- <https://github.com/microsoft/ADBench>
+-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
+-- Master thesis at Utrecht University. (Appendix B.1)
+-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
+-- <https://tomsmeding.com/f/master.pdf>
+--
+-- The 'wrong' argument, when set to True, changes the objective function to
+-- one with a bug that makes a certain `build` result unused. This
+-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's
+-- dense, even though it may be a zero (i.e. empty). The "unused" test in
+-- test/Main.hs tries to isolate this case, but the wrong version of
+-- gmmObjective is here to check (after that bug is fixed) whether it really
+-- fixes the original bug.
+gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+gmmObjective wrong = fromNamed $
+ lambda #N $ lambda #D $ lambda #K $
+ lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
+ lambda #X $ lambda #m $
+ lambda #k1 $ lambda #k2 $ lambda #k3 $
+ body $
+ let -- We have:
+ -- sum (exp (x - max(x)))
+ -- = sum (exp x / exp (max(x)))
+ -- = sum (exp x) / exp (max(x))
+ -- Hence:
+ -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*)
+ --
+ -- So:
+ -- d/dxi log (sum (exp x))
+ -- = 1/(sum (exp x)) * d/dxi sum (exp x)
+ -- = 1/(sum (exp x)) * sum (d/dxi exp x)
+ -- = 1/(sum (exp x)) * exp xi
+ -- = exp xi / sum (exp x)
+ -- (by (*))
+ -- = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
+ -- = exp (xi - max(x)) / sum (exp (x - max(x)))
+ logsumexp' = lambda @(TVec R) #vec $ body $
+ let_ #m (maximum1i #vec) $
+ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
+ -- custom (#_ :-> #v :->
+ -- let_ #m (idx0 (maximum1i #v)) $
+ -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
+ -- (#_ :-> #v :->
+ -- let_ #m (idx0 (maximum1i #v)) $
+ -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
+ -- let_ #s (idx0 (sum1i #ex)) $
+ -- pair (log #s + #m)
+ -- (pair #ex #s))
+ -- (#tape :-> #d :->
+ -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
+ -- nil #vec
+ logsumexp v = inline logsumexp' (SNil .$ v)
+
+ mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
+ let_ #hei (snd_ (fst_ (shape #mat))) $
+ let_ #wid (snd_ (shape #mat)) $
+ build1 #hei $ #i :->
+ idx0 (sum1i (build1 #wid $ #j :->
+ #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+ m *@ v = inline mulmatvec (SNil .$ m .$ v)
+
+ subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $
+ build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i
+ a .- b = inline subvec (SNil .$ a .$ b)
+
+ matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $
+ build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j)
+ m .! i = inline matrow (SNil .$ m .$ i)
+
+ normsq' = lambda @(TVec R) #vec $ body $
+ idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x)))
+ normsq v = inline normsq' (SNil .$ v)
+
+ qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
+ let_ #n (snd_ (shape #q)) $
+ build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :->
+ let_ #i (snd_ (fst_ #idx)) $
+ let_ #j (snd_ #idx) $
+ if_ (#i .== #j)
+ (exp (#q ! pair nil #i))
+ (if_ (#i .> #j)
+ (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j)
+ else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j))
+ 0.0)
+ qmat q l = inline qmat' (SNil .$ q .$ l)
+ in let_ #k2arr (unit #k2) $
+ - #k1
+ + idx0 (sum1i (build1 #N $ #i :->
+ logsumexp (build1 #K $ #k :->
+ #alpha ! pair nil #k
+ + idx0 (sum1i (#Q .! #k))
+ - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k))))))
+ - toFloat_ #N * logsumexp #alpha
+ + idx0 (sum1i (build1 #K $ #k :->
+ idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
+ - toFloat_ #m * idx0 (sum1i (#Q .! #k))))
+ - #k3
diff --git a/src/CHAD/Example/Types.hs b/src/CHAD/Example/Types.hs
new file mode 100644
index 0000000..1e2f72d
--- /dev/null
+++ b/src/CHAD/Example/Types.hs
@@ -0,0 +1,11 @@
+{-# LANGUAGE DataKinds #-}
+module CHAD.Example.Types where
+
+import CHAD.AST
+import CHAD.Data
+
+
+type R = TScal TF64
+type I64 = TScal TI64
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
new file mode 100644
index 0000000..7126e10
--- /dev/null
+++ b/src/CHAD/ForwardAD.hs
@@ -0,0 +1,270 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.ForwardAD where
+
+import Data.Bifunctor (bimap)
+import System.IO.Unsafe
+
+-- import Debug.Trace
+-- import CHAD.AST.Pretty
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.Compile
+import CHAD.Data
+import CHAD.ForwardAD.DualNumbers
+import CHAD.Interpreter
+import CHAD.Interpreter.Rep
+
+
+-- | Tangent along a type (coincides with cotangent for these types)
+type family Tan t where
+ Tan TNil = TNil
+ Tan (TPair a b) = TPair (Tan a) (Tan b)
+ Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
+ Tan (TMaybe t) = TMaybe (Tan t)
+ Tan (TArr n t) = TArr n (Tan t)
+ Tan (TScal t) = TanS t
+
+type family TanS t where
+ TanS TI32 = TNil
+ TanS TI64 = TNil
+ TanS TF32 = TScal TF32
+ TanS TF64 = TScal TF64
+ TanS TBool = TNil
+
+type family TanE env where
+ TanE '[] = '[]
+ TanE (t : env) = Tan t : TanE env
+
+tanty :: STy t -> STy (Tan t)
+tanty STNil = STNil
+tanty (STPair a b) = STPair (tanty a) (tanty b)
+tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
+tanty (STMaybe t) = STMaybe (tanty t)
+tanty (STArr n t) = STArr n (tanty t)
+tanty (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+tanty STAccum{} = error "Accumulators not allowed in input program"
+
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
+
+zeroTan :: STy t -> Rep t -> Rep (Tan t)
+zeroTan STNil () = ()
+zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
+zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
+zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
+zeroTan (STMaybe _) Nothing = Nothing
+zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
+zeroTan (STArr _ t) x = fmap (zeroTan t) x
+zeroTan (STScal STI32) _ = ()
+zeroTan (STScal STI64) _ = ()
+zeroTan (STScal STF32) _ = 0.0
+zeroTan (STScal STF64) _ = 0.0
+zeroTan (STScal STBool) _ = ()
+zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+
+tanScalars :: STy t -> Rep (Tan t) -> [Double]
+tanScalars STNil () = []
+tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
+tanScalars (STEither a _) (Left x) = tanScalars a x
+tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
+tanScalars (STMaybe _) Nothing = []
+tanScalars (STMaybe t) (Just x) = tanScalars t x
+tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
+tanScalars (STScal STI32) _ = []
+tanScalars (STScal STI64) _ = []
+tanScalars (STScal STF32) x = [realToFrac x]
+tanScalars (STScal STF64) x = [x]
+tanScalars (STScal STBool) _ = []
+tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+
+tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
+tanEScalars SNil SNil = []
+tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs
+
+unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
+unzipDN STNil _ = ((), ())
+unzipDN (STPair a b) (d1, d2) =
+ let (x, dx) = unzipDN a d1
+ (y, dy) = unzipDN b d2
+ in ((x, y), (dx, dy))
+unzipDN (STEither a b) d = case d of
+ Left d1 -> bimap Left Left (unzipDN a d1)
+ Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
+unzipDN (STMaybe t) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just d' -> bimap Just Just (unzipDN t d')
+unzipDN (STArr _ t) d =
+ let pairs = arrayMap (unzipDN t) d
+ in (arrayMap fst pairs, arrayMap snd pairs)
+unzipDN (STScal ty) d = case ty of
+ STI32 -> (d, ())
+ STI64 -> (d, ())
+ STF32 -> d
+ STF64 -> d
+ STBool -> (d, ())
+unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+
+dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
+dotprodTan STNil _ _ = 0.0
+dotprodTan (STPair a b) (x, y) (x', y') =
+ dotprodTan a x x' + dotprodTan b y y'
+dotprodTan (STEither a b) x y = case (x, y) of
+ (Left x', Left y') -> dotprodTan a x' y'
+ (Right x', Right y') -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
+dotprodTan (STMaybe t) x y = case (x, y) of
+ (Nothing, Nothing) -> 0.0
+ (Just x', Just y') -> dotprodTan t x' y'
+ _ -> error "dotprodTan: incompatible Maybe alternatives"
+dotprodTan (STArr _ t) x y =
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
+ | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
+ | otherwise -> error "dotprodTan: incompatible array shapes"
+dotprodTan (STScal ty) x y = case ty of
+ STI32 -> 0.0
+ STI64 -> 0.0
+ STF32 -> realToFrac @Float @Double (x * y)
+ STF64 -> x * y
+ STBool -> 0.0
+dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+
+-- -- Primal expression must be duplicable
+-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
+-- dnConstE STNil _ = ENil ext
+-- dnConstE (STPair t1 t2) e =
+-- -- This creates fst/snd stacks of unbounded size, but let's not care here
+-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
+-- dnConstE (STEither t1 t2) e =
+-- ECase ext e
+-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
+-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
+-- dnConstE (STMaybe t) e =
+-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
+-- dnConstE (STArr n t) e =
+-- EBuild ext n (EShape ext e)
+-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
+-- dnConstE (STScal t) e = case t of
+-- STI32 -> e
+-- STI64 -> e
+-- STF32 -> EPair ext e (EConst ext STF32 0.0)
+-- STF64 -> EPair ext e (EConst ext STF64 0.0)
+-- STBool -> e
+-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConst :: STy t -> Rep t -> Rep (DN t)
+dnConst STNil = const ()
+dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
+dnConst (STMaybe t) = fmap (dnConst t)
+dnConst (STArr _ t) = arrayMap (dnConst t)
+dnConst (STScal t) = case t of
+ STI32 -> id
+ STI64 -> id
+ STF32 -> (,0.0)
+ STF64 -> (,0.0)
+ STBool -> id
+dnConst STAccum{} = error "Accumulators not allowed in input program"
+
+-- | Given a function that computes the forward derivative for a particular
+-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
+-- @t@ input.
+type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)
+
+dnOnehots :: STy t -> Rep t -> RevByFwd t
+dnOnehots STNil _ = \_ -> ()
+dnOnehots (STPair t1 t2) (x, y) =
+ \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,)))
+dnOnehots (STEither t1 t2) e =
+ case e of
+ Left x -> \f -> Left (dnOnehots t1 x (f . Left))
+ Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
+dnOnehots (STMaybe t) m =
+ case m of
+ Nothing -> \_ -> Nothing
+ Just x -> \f -> Just (dnOnehots t x (f . Just))
+dnOnehots (STArr _ t) a =
+ \f ->
+ arrayGenerate (arrayShape a) $ \idx ->
+ dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
+ if i == idx then oh else dnConst t (arrayIndex a i)))
+dnOnehots (STScal t) x = case t of
+ STI32 -> \_ -> ()
+ STI64 -> \_ -> ()
+ STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
+ STF64 -> \f -> f (x, 1.0)
+ STBool -> \_ -> ()
+dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
+dnConstEnv SNil SNil = SNil
+dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val
+
+type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)
+
+dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
+dnOnehotEnvs SNil SNil = \_ -> SNil
+dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
+ \f ->
+ Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
+ `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
+
+data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
+
+makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
+makeFwdADArtifactInterp env expr =
+ let dexpr = dfwdDN expr
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
+
+{-# NOINLINE makeFwdADArtifactCompile #-}
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String)
+makeFwdADArtifactCompile env expr = do
+ (fun, output) <- compile (dne env) (dfwdDN expr)
+ return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output)
+
+drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
+
+drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd (FwdADArtifact env outty fun) input dres =
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
+ let (_, outtan) = unzipDN outty (fun dnInput)
+ in dotprodTan outty outtan dres
diff --git a/src/CHAD/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs
new file mode 100644
index 0000000..a71efc8
--- /dev/null
+++ b/src/CHAD/ForwardAD/DualNumbers.hs
@@ -0,0 +1,231 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module CHAD.ForwardAD.DualNumbers (
+ dfwdDN,
+ DN, DNS, DNE, dn, dne,
+) where
+
+import CHAD.AST
+import CHAD.Data
+import CHAD.ForwardAD.DualNumbers.Types
+
+
+dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
+dnPreservesTupIx SZ = Refl
+dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl
+
+convIdx :: Idx env t -> Idx (DNE env) (DN t)
+convIdx IZ = IZ
+convIdx (IS i) = IS (convIdx i)
+
+scalTyCase :: SScalTy t
+ -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
+ -> (DN (TScal t) ~ TScal t => r)
+ -> r
+scalTyCase STF32 k1 _ = k1
+scalTyCase STF64 k1 _ = k1
+scalTyCase STI32 _ k2 = k2
+scalTyCase STI64 _ k2 = k2
+scalTyCase STBool _ k2 = k2
+
+floatingDual :: ScalIsFloating t ~ True
+ => SScalTy t
+ -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
+floatingDual STF32 k = k
+floatingDual STF64 k = k
+
+-- | Argument does not need to be duplicable.
+dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
+dop = \case
+ OAdd t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
+ (EOp ext (OAdd t))
+ OMul t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
+ (EOp ext (OMul t))
+ ONeg t -> scalTyCase t
+ (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
+ (EOp ext (ONeg t))
+ OLt t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
+ (EOp ext (OLt t))
+ OLe t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
+ (EOp ext (OLe t))
+ OEq t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
+ (EOp ext (OEq t))
+ ONot -> EOp ext ONot
+ OAnd -> EOp ext OAnd
+ OOr -> EOp ext OOr
+ OIf -> EOp ext OIf
+ ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
+ OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
+ ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (recip' t x)
+ (mul t (neg t (recip' t (mul t x x))) dx))
+ OExp t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx))
+ OLog t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OLog t) x)
+ (mul t (recip' t x) dx))
+ OIDiv t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OIDiv t))
+ OMod t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OMod t))
+ where
+ add :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ add t a b = EOp ext (OAdd t) (EPair ext a b)
+
+ mul :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ mul t a b = EOp ext (OMul t) (EPair ext a b)
+
+ neg :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ neg t = EOp ext (ONeg t)
+
+ recip' :: ScalIsFloating t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ recip' t = EOp ext (ORecip t)
+
+ unFloat :: DN a ~ TPair a a
+ => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
+ -> Ex env (DN a) -> Ex env (DN b)
+ unFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext var, ESnd ext var)
+
+ binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
+ => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
+ -> Ex env (DN a) -> Ex env (DN b)
+ binFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
+ (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
+
+zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
+zeroScalarConst STI32 = EConst ext STI32 0
+zeroScalarConst STI64 = EConst ext STI64 0
+zeroScalarConst STF32 = EConst ext STF32 0.0
+zeroScalarConst STF64 = EConst ext STF64 0.0
+
+dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
+dfwdDN = \case
+ EVar _ t i -> EVar ext (dn t) (convIdx i)
+ ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
+ EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
+ EFst _ e -> EFst ext (dfwdDN e)
+ ESnd _ e -> ESnd ext (dfwdDN e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (dn t) (dfwdDN e)
+ EInr _ t e -> EInr ext (dn t) (dfwdDN e)
+ ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ENothing _ t -> ENothing ext (dn t)
+ EJust _ e -> EJust ext (dfwdDN e)
+ EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2)
+ ELInl _ t e -> ELInl ext (dn t) (dfwdDN e)
+ ELInr _ t e -> ELInr ext (dn t) (dfwdDN e)
+ ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c)
+ EConstArr _ n t x -> scalTyCase t
+ (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
+ (EConstArr ext n t x))
+ (EConstArr ext n t x)
+ EBuild _ n a b
+ | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
+ EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)
+ ESum1Inner _ e ->
+ let STArr n (STScal t) = typeOf e
+ pairty = (STPair (STScal t) (STScal t))
+ in scalTyCase t
+ (ELet ext (dfwdDN e) $
+ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ)))
+ (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ))))
+ (ESum1Inner ext (dfwdDN e))
+ EUnit _ e -> EUnit ext (dfwdDN e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
+ EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b)
+ EReshape _ n esh e
+ | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)
+ EConst _ t x -> scalTyCase t
+ (EPair ext (EConst ext t x) (EConst ext t 0.0))
+ (EConst ext t x)
+ EIdx0 _ e -> EIdx0 ext (dfwdDN e)
+ EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
+ EIdx _ a b
+ | STArr n _ <- typeOf a
+ , Refl <- dnPreservesTupIx n
+ -> EIdx ext (dfwdDN a) (dfwdDN b)
+ EShape _ e
+ | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
+ -> EShape ext (dfwdDN e)
+ EOp _ op e -> dop op (dfwdDN e)
+ ECustom _ _ _ _ pr _ _ e1 e2 ->
+ ELet ext (dfwdDN e1) $
+ ELet ext (weakenExpr WSink (dfwdDN e2)) $
+ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
+ ERecompute _ e -> dfwdDN e
+ EError _ t s -> EError ext (dn t) s
+
+ EWith{} -> err_accum
+ EAccum{} -> err_accum
+ EDeepZero{} -> err_monoid
+ EZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+ where
+ err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
+ err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program"
+
+ deriv_extremum :: ScalIsNumeric t ~ True
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
+ deriv_extremum extremum e =
+ let STArr (SS n) (STScal t) = typeOf e
+ t2 = STPair (STScal t) (STScal t)
+ ta2 = STArr (SS n) t2
+ tIxN = tTup (sreplicate (SS n) tIx)
+ in scalTyCase t
+ (ELet ext (dfwdDN e) $
+ ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
+ ezip (EVar ext (STArr n (STScal t)) IZ)
+ (ESum1Inner ext
+ {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
+ (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
+ ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
+ (EFst ext (EVar ext t2 IZ))
+ (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
+ (EFst ext (EVar ext tIxN (IS IZ)))))))
+ (ESnd ext (EVar ext t2 (IS IZ)))
+ (zeroScalarConst t))))
+ (extremum (dfwdDN e))
diff --git a/src/CHAD/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs
new file mode 100644
index 0000000..5d5dd9e
--- /dev/null
+++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs
@@ -0,0 +1,48 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.ForwardAD.DualNumbers.Types where
+
+import CHAD.AST.Types
+import CHAD.Data
+
+
+-- | Dual-numbers transformation
+type family DN t where
+ DN TNil = TNil
+ DN (TPair a b) = TPair (DN a) (DN b)
+ DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TLEither a b) = TLEither (DN a) (DN b)
+ DN (TMaybe t) = TMaybe (DN t)
+ DN (TArr n t) = TArr n (DN t)
+ DN (TScal t) = DNS t
+
+type family DNS t where
+ DNS TF32 = TPair (TScal TF32) (TScal TF32)
+ DNS TF64 = TPair (TScal TF64) (TScal TF64)
+ DNS TI32 = TScal TI32
+ DNS TI64 = TScal TI64
+ DNS TBool = TScal TBool
+
+type family DNE env where
+ DNE '[] = '[]
+ DNE (t : ts) = DN t : DNE ts
+
+dn :: STy t -> STy (DN t)
+dn STNil = STNil
+dn (STPair a b) = STPair (dn a) (dn b)
+dn (STEither a b) = STEither (dn a) (dn b)
+dn (STLEither a b) = STLEither (dn a) (dn b)
+dn (STMaybe t) = STMaybe (dn t)
+dn (STArr n t) = STArr n (dn t)
+dn (STScal t) = case t of
+ STF32 -> STPair (STScal STF32) (STScal STF32)
+ STF64 -> STPair (STScal STF64) (STScal STF64)
+ STI32 -> STScal STI32
+ STI64 -> STScal STI64
+ STBool -> STScal STBool
+dn STAccum{} = error "Accum in source program"
+
+dne :: SList STy env -> SList STy (DNE env)
+dne SNil = SNil
+dne (t `SCons` env) = dn t `SCons` dne env
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs
new file mode 100644
index 0000000..a9421e6
--- /dev/null
+++ b/src/CHAD/Interpreter.hs
@@ -0,0 +1,471 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Interpreter (
+ interpret,
+ interpretOpen,
+ Value(..),
+) where
+
+import Control.Monad (foldM, join, when, forM_)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict (runStateT, get, put)
+import Data.Bifunctor (bimap)
+import Data.Bitraversable (bitraverse)
+import Data.Char (isSpace)
+import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
+import Data.Int (Int64)
+import Data.IORef
+import Data.Tuple (swap)
+import System.IO (hPutStrLn, stderr)
+import System.IO.Unsafe (unsafePerformIO)
+
+import Debug.Trace
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Pretty
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+import CHAD.Interpreter.Rep
+
+
+newtype AcM s a = AcM { unAcM :: IO a }
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
+
+acmDebugLog :: String -> AcM s ()
+acmDebugLog s = AcM (hPutStrLn stderr s)
+
+data V t = V (STy t) (Rep t)
+
+interpret :: Ex '[] t -> Rep t
+interpret = interpretOpen False SNil SNil
+
+-- | Bool: whether to trace execution with debug prints (very verbose)
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
+ runAcM $
+ let ?depth = 0
+ ?prints = prints
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
+
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
+interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
+ let dep = ?depth
+ let lenlimit = max 20 (100 - dep)
+ let replace a b = map (\c -> if c == a then b else c)
+ let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
+ | otherwise = replace '\n' ' ' s
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
+ res <- let ?depth = dep + 1 in interpret'Rec env e
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
+ return res
+
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
+interpret'Rec env = \case
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
+ ELet _ a b -> do
+ x <- interpret' env a
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
+ expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
+ EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
+ EFst _ e -> fst <$> interpret' env e
+ ESnd _ e -> snd <$> interpret' env e
+ ENil _ -> return ()
+ EInl _ _ e -> Left <$> interpret' env e
+ EInr _ _ e -> Right <$> interpret' env e
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
+ ENothing _ _ -> return Nothing
+ EJust _ e -> Just <$> interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
+ ELNil _ _ _ -> return Nothing
+ ELInl _ _ e -> Just . Left <$> interpret' env e
+ ELInr _ _ e -> Just . Right <$> interpret' env e
+ ELCase _ e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Nothing -> interpret' env a
+ Just (Left x) -> interpret' (V t1 x `SCons` env) b
+ Just (Right y) -> interpret' (V t2 y `SCons` env) c
+ EConstArr _ _ _ v -> return v
+ EBuild _ dim a b -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ EMap _ a b -> do
+ let STArr _ t = typeOf b
+ arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b
+ EFold1Inner _ _ a b c -> do
+ let t = typeOf b
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ ESum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
+ EReplicate1Inner _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arr <- interpret' env b
+ let sh = arrayShape arr
+ return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx)
+ EMaximum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EMinimum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EReshape _ dim esh e -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh
+ arr <- interpret' env e
+ return $ arrayReshape sh arr
+ EZip _ a b -> do
+ arr1 <- interpret' env a
+ arr2 <- interpret' env b
+ let sh = arrayShape arr1
+ when (sh /= arrayShape arr2) $
+ error "Interpreter: mismatched shapes in EZip"
+ return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))
+ EFold1InnerD1 _ _ a b c -> do
+ let t = typeOf b
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ -- TODO: this is very inefficient, even for an interpreter; with mutable
+ -- arrays this can be a lot better with no lists
+ res <- arrayGenerateM sh $ \idx -> do
+ (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ return (y, arrayFromList (ShNil `ShCons` n) stores)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EFold1InnerD2 _ _ ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
+ bog <- interpret' env ebog
+ arrctg <- interpret' env ed
+ let sh `ShCons` n = arrayShape bog
+ when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2"
+ res <- arrayGenerateM sh $ \idx -> do
+ let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
+ loop i !ctg !inpctgs = do
+ let b = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f b ctg
+ loop (i - 1) ctg1 (ctg2 : inpctgs)
+ (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
+ return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EConst _ _ v -> return v
+ EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
+ EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
+ EIdx _ a b ->
+ let STArr n _ = typeOf a
+ in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
+ EOp _ op e -> interpretOp op <$> interpret' env e
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
+ e1' <- interpret' env e1
+ e2' <- interpret' env e2
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
+ EWith _ t e1 e2 -> do
+ initval <- interpret' env e1
+ withAccum t (typeOf e2) initval $ \accum ->
+ interpret' (V (STAccum t) accum `SCons` env) e2
+ EAccum _ t p e1 sp e2 e3 -> do
+ idx <- interpret' env e1
+ val <- interpret' env e2
+ accum <- interpret' env e3
+ accumAddSparseD t p accum idx sp val
+ EZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
+ EPlus _ t a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ addM t a' b'
+ EOneHot _ t p a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ onehotM p t a' b'
+ EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
+
+interpretOp :: SOp a t -> Rep a -> Rep t
+interpretOp op arg = case op of
+ OAdd st -> numericIsNum st $ uncurry (+) arg
+ OMul st -> numericIsNum st $ uncurry (*) arg
+ ONeg st -> numericIsNum st $ negate arg
+ OLt st -> numericIsNum st $ uncurry (<) arg
+ OLe st -> numericIsNum st $ uncurry (<=) arg
+ OEq st -> styIsEq st $ uncurry (==) arg
+ ONot -> not arg
+ OAnd -> uncurry (&&) arg
+ OOr -> uncurry (||) arg
+ OIf -> if arg then Left () else Right ()
+ ORound64 -> round arg
+ OToFl64 -> fromIntegral arg
+ ORecip st -> floatingIsFractional st $ recip arg
+ OExp st -> floatingIsFractional st $ exp arg
+ OLog st -> floatingIsFractional st $ log arg
+ OIDiv st -> integralIsIntegral st $ uncurry quot arg
+ OMod st -> integralIsIntegral st $ uncurry rem arg
+ where
+ styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
+ styIsEq STI32 = id
+ styIsEq STI64 = id
+ styIsEq STF32 = id
+ styIsEq STF64 = id
+ styIsEq STBool = id
+
+zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t
+zeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi))
+ SMTLEither _ _ -> Nothing
+ SMTMaybe _ -> Nothing
+ SMTArr _ t -> arrayMap (zeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+addM :: SMTy t -> Rep t -> Rep t -> Rep t
+addM typ a b = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b))
+ SMTLEither t1 t2 -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y))
+ _ -> error "Plus of inconsistent LEithers"
+ SMTMaybe t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) -> Just (addM t x y)
+ SMTArr _ t ->
+ let sh1 = arrayShape a
+ sh2 = arrayShape b
+ in if | shapeSize sh1 == 0 -> b
+ | shapeSize sh2 == 0 -> a
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
+ SMTScal sty -> numericIsNum sty $ a + b
+
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
+onehotM SAPHere _ _ val = val
+onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
+onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
+onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val))
+onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val))
+onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val)
+onehotM (SAPArrIdx prj) (SMTArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx
+
+withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+withAccum t _ initval f = AcM $ do
+ accum <- newAcDense t initval
+ out <- unAcM $ f accum
+ val <- readAc t accum
+ return (out, val)
+
+newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
+newAcDense typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val
+ SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val
+ SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
+ SMTArr _ t1 -> arrayMapM (newAcDense t1) val
+ SMTScal _ -> newIORef val
+
+onehotArray :: Monad m
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
+ -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
+ let arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ziarr
+ !linindex = toLinearIndex arrsh arrindex
+ in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
+
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
+ SMTScal _ -> readIORef val
+
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+-- TODO: makeval is always 'error' now. Simplify?
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
+realiseMaybeSparse ref makeval modifyval =
+ -- Try modifying what's already in ref. The 'join' makes the snd
+ -- of the function's return value a _continuation_ that is run after
+ -- the critical section ends.
+ AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do val <- makeval
+ join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+ Nothing -> (Just val, return ())
+ Just val' -> (ac', unAcM $ modifyval val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just val -> (ac, unAcM $ modifyval val)
+
+
+numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
+numericIsNum STI32 = id
+numericIsNum STI64 = id
+numericIsNum STF32 = id
+numericIsNum STF64 = id
+
+floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r
+floatingIsFractional STF32 = id
+floatingIsFractional STF64 = id
+
+integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r
+integralIsIntegral STI32 = id
+integralIsIntegral STI64 = id
+
+unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
+ -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
+unTupRepIdx nil _ SZ _ = nil
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
+
+tupRepIdx :: (forall m. f (S m) -> (f m, Int))
+ -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
+tupRepIdx _ SZ _ = ()
+tupRepIdx uncons (SS n) tup =
+ let (tup', i) = uncons tup
+ in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i
+
+ixUncons :: Index (S n) -> (Index n, Int)
+ixUncons (IxCons idx i) = (idx, i)
+
+shUncons :: Shape (S n) -> (Shape n, Int)
+shUncons (ShCons idx i) = (idx, i)
+
+mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b)
+mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f'
+ where f' x = do
+ s <- get
+ (s', y) <- lift $ f s x
+ put s'
+ return y
diff --git a/src/CHAD/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs
new file mode 100644
index 0000000..8e5c040
--- /dev/null
+++ b/src/CHAD/Interpreter/Accum.hs
@@ -0,0 +1,366 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+module CHAD.Interpreter.Accum (
+ AcM,
+ runAcM,
+ Rep',
+ Accum,
+ withAccum,
+ accumAdd,
+ inParallel,
+) where
+
+import Control.Concurrent
+import Control.Monad (when, forM_)
+import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
+import Foreign.Storable (sizeOf)
+import GHC.Exts
+import GHC.Float
+import GHC.Int
+import GHC.IO (IO(..))
+import GHC.Word
+import System.IO.Unsafe (unsafePerformIO)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.Data
+
+
+newtype AcM s a = AcM (IO a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
+
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
+
+-- | Floats and integers are accumulated; booleans are left as-is.
+data Accum s t = Accum (STy t) (ForeignPtr ())
+
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
+
+tSize' :: Proxy s -> STy t -> Int
+tSize' p typ = case typ of
+ STNil -> 0
+ STPair a b -> tSize' p a + tSize' p b
+ STEither a b -> 1 + max (tSize' p a) (tSize' p b)
+ -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+ STMaybe t -> tSize' p (STEither STNil t)
+ STArr ndim t ->
+ case val of
+ Nothing -> error "Nested arrays not supported in this implementation"
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
+ STScal sty -> goScal sty
+ STAccum{} -> error "Nested accumulators unsupported"
+ where
+ goScal :: SScalTy t -> Int
+ goScal STI32 = 4
+ goScal STI64 = 8
+ goScal STF32 = 4
+ goScal STF64 = 8
+ goScal STBool = 1
+
+-- | This operation does not commute with 'accumAdd', so it must be used with
+-- care. Furthermore it must be used on exactly the same value as tSize was
+-- called on. Hence it lives in IO, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
+accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ go inarr ty val off = case ty of
+ STNil -> return off
+ STPair a b -> do
+ off1 <- go inarr a (fst val) off
+ go inarr b (snd val) off1
+ STEither a b -> do
+ let !(I# off#) = off
+ off1 <- case val of
+ Left x -> do
+ let !(I8# tag#) = 0
+ writeInt8# addr# off# tag#
+ go inarr a x (off + 1)
+ Right y -> do
+ let !(I8# tag#) = 1
+ writeInt8# addr# off# tag#
+ go inarr b y (off + 1)
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
+ else return off1
+ -- Representation is the same, but add operation is different
+ STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
+ STArr _ t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ off1 <- goShape (arrayShape val) off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = arraySize val
+ traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
+ return (off1 + eltsize * n)
+ STScal sty -> goScal sty val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goShape :: Shape n -> Int -> IO Int
+ goShape ShNil off = return off
+ goShape (ShCons sh n) off = do
+ off1@(I# off1#) <- goShape sh off
+ let !(I64# n'#) = fromIntegral n
+ writeInt64# addr# off1# n'#
+ return (off1 + 8)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
+ goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
+ goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
+ goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
+ goScal STBool b off@(I# off#) = do
+ let !(I8# i) = fromIntegral (fromEnum b)
+ off + 1 <$ writeInt8# addr# off# i
+
+ in () <$ go False topty top_value 0
+
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
+accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
+ go inarr ty off = case ty of
+ STNil -> return (off, ())
+ STPair a b -> do
+ (off1, x) <- go inarr a off
+ (off2, y) <- go inarr b off1
+ return (off1 + off2, (x, y))
+ STEither a b -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ (off1, val) <- case tag of
+ 0 -> fmap Left <$> go inarr a (off + 1)
+ 1 -> fmap Right <$> go inarr b (off + 1)
+ _ -> error "Invalid tag in accum memory"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
+ else return (off1, val)
+ -- Representation is the same, but add operation is different
+ STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
+ STArr ndim t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# ndim off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = shapeSize sh
+ arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
+ return (off1 + eltsize * n, arr)
+ STScal sty -> goScal sty off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
+ goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
+ goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
+ goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
+ goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
+ goScal STBool off@(I# off#) = do
+ i8 <- readInt8 addr# off#
+ return (off + 1, toEnum (fromIntegral i8))
+
+ in snd <$> go False topty 0
+
+readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
+readShape _ SZ off = return (off, ShNil)
+readShape mbarr (SS ndim) off = do
+ (off1@(I# off1#), sh) <- readShape mbarr ndim off
+ n' <- readInt64 mbarr off1#
+ return (off1 + 8, ShCons sh (fromIntegral n'))
+
+-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
+-- the list.
+data InvShape n where
+ IShNil :: InvShape Z
+ IShCons :: Int -- ^ How many subarrays are there?
+ -> Int -- ^ What is the size of all subarrays together?
+ -> InvShape n -- ^ Sub array inverted shape
+ -> InvShape (S n)
+
+ishSize :: InvShape n -> Int
+ishSize IShNil = 1
+ishSize (IShCons _ sz _) = sz
+
+invertShape :: forall n. Shape n -> InvShape n
+invertShape | Refl <- lemPlusZero @n = flip go IShNil
+ where
+ go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
+ go ShNil ish = ish
+ go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
+
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
+accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
+ go inarr ty SZ () val off = () <$ performAdd inarr ty val off
+ go inarr ty (SS dep) idx val off = case (ty, idx, val) of
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ (STMaybe t, _, _) -> _ idx val
+ (STArr rank eltty, _, _)
+ | inarr -> error "Nested arrays"
+ | otherwise -> do
+ (off1, ish) <- second invertShape <$> readShape addr# rank off
+ goArr (SS dep) ish eltty idx val off1
+ (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
+ (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
+ (STAccum{}, _, _) -> error "Nested accumulators unsupported"
+
+ goArr :: SNat i' -> InvShape n -> STy t'
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
+ goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
+ goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
+ goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
+ let i' = fromIntegral @(Rep' s TIx) @Int i
+ when (i' < 0 || i' >= n) $
+ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
+ goArr depm1 ish t1 idx val (off + i' * ishSize ish)
+
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
+ performAddArr arraySz eltty val off = do
+ let eltsize = tSize' (Proxy @s) eltty Nothing
+ forM_ [0 .. arraySz - 1] $ \lini ->
+ performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
+ return (off + arraySz * eltsize)
+
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ performAdd inarr ty val off = case ty of
+ STNil -> return off
+ STPair t1 t2 -> do
+ off1 <- performAdd inarr t1 (fst val) off
+ performAdd inarr t2 (snd val) off1
+ STEither t1 t2 -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ off1 <- case (val, tag) of
+ (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
+ (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
+ _ -> error "accumAdd: Tag mismatch for Either"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
+ else return off1
+ STArr n ty'
+ | inarr -> error "Nested array"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# n off
+ performAddArr (shapeSize sh) ty' val off1
+ STScal ty' -> performAddScal ty' val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ performAddScal STI32 (I32# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 4
+ = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
+ | otherwise
+ = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
+ performAddScal STI64 (I64# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 8
+ = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
+ | otherwise
+ = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
+ performAddScal STF32 x off@(I# off#) =
+ off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
+ performAddScal STF64 x off@(I# off#) =
+ off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
+ performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
+
+ casLoop :: Eq w
+ => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
+ -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
+ -> Addr# -- ^ Address to attempt to modify
+ -> (w -> w) -- ^ Operation to apply to the value
+ -> IO ()
+ casLoop readOp casOp addr modify = readOp addr 0# >>= loop
+ where
+ loop value = do
+ value' <- casOp addr value (modify value)
+ if value == value'
+ then return ()
+ else loop value'
+
+ in () <$ go False topty top_depth top_index top_value 0
+
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
+withAccum ty start fun = do
+ -- The initial write must happen before any of the adds or reads, so it makes
+ -- sense to put it in IO together with the allocation, instead of in AcM.
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
+ ptr <- newForeignPtr finalizerFree buffer
+ let accum = Accum ty ptr
+ accumWrite accum start
+ return accum
+ b <- fun accum
+ out <- accumRead accum
+ return (b, out)
+
+inParallel :: [AcM s t] -> AcM s [t]
+inParallel actions = AcM $ do
+ mvars <- mapM (\_ -> newEmptyMVar) actions
+ forM_ (zip actions mvars) $ \(AcM action, var) ->
+ forkIO $ action >>= putMVar var
+ mapM takeMVar mvars
+
+-- | Offset is in bytes.
+readInt8 :: Addr# -> Int# -> IO Int8
+readInt32 :: Addr# -> Int# -> IO Int32
+readInt64 :: Addr# -> Int# -> IO Int64
+readWord32 :: Addr# -> Int# -> IO Word32
+readWord64 :: Addr# -> Int# -> IO Word64
+readFloat :: Addr# -> Int# -> IO Float
+readDouble :: Addr# -> Int# -> IO Double
+readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #)
+readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #)
+readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #)
+readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #)
+readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #)
+readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #)
+readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #)
+
+writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
+writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
+writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
+writeFloat# :: Addr# -> Int# -> Float# -> IO ()
+writeDouble# :: Addr# -> Int# -> Double# -> IO ()
+writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+
+fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
+fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #)
+
+atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
+atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
+atomicCasWord32Addr addr (W32# expected) (W32# desired) =
+ IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #)
+atomicCasWord64Addr addr (W64# expected) (W64# desired) =
+ IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #)
diff --git a/src/CHAD/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs
new file mode 100644
index 0000000..8e5c040
--- /dev/null
+++ b/src/CHAD/Interpreter/AccumOld.hs
@@ -0,0 +1,366 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+module CHAD.Interpreter.Accum (
+ AcM,
+ runAcM,
+ Rep',
+ Accum,
+ withAccum,
+ accumAdd,
+ inParallel,
+) where
+
+import Control.Concurrent
+import Control.Monad (when, forM_)
+import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
+import Foreign.Storable (sizeOf)
+import GHC.Exts
+import GHC.Float
+import GHC.Int
+import GHC.IO (IO(..))
+import GHC.Word
+import System.IO.Unsafe (unsafePerformIO)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.Data
+
+
+newtype AcM s a = AcM (IO a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
+
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
+
+-- | Floats and integers are accumulated; booleans are left as-is.
+data Accum s t = Accum (STy t) (ForeignPtr ())
+
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
+
+tSize' :: Proxy s -> STy t -> Int
+tSize' p typ = case typ of
+ STNil -> 0
+ STPair a b -> tSize' p a + tSize' p b
+ STEither a b -> 1 + max (tSize' p a) (tSize' p b)
+ -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+ STMaybe t -> tSize' p (STEither STNil t)
+ STArr ndim t ->
+ case val of
+ Nothing -> error "Nested arrays not supported in this implementation"
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
+ STScal sty -> goScal sty
+ STAccum{} -> error "Nested accumulators unsupported"
+ where
+ goScal :: SScalTy t -> Int
+ goScal STI32 = 4
+ goScal STI64 = 8
+ goScal STF32 = 4
+ goScal STF64 = 8
+ goScal STBool = 1
+
+-- | This operation does not commute with 'accumAdd', so it must be used with
+-- care. Furthermore it must be used on exactly the same value as tSize was
+-- called on. Hence it lives in IO, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
+accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ go inarr ty val off = case ty of
+ STNil -> return off
+ STPair a b -> do
+ off1 <- go inarr a (fst val) off
+ go inarr b (snd val) off1
+ STEither a b -> do
+ let !(I# off#) = off
+ off1 <- case val of
+ Left x -> do
+ let !(I8# tag#) = 0
+ writeInt8# addr# off# tag#
+ go inarr a x (off + 1)
+ Right y -> do
+ let !(I8# tag#) = 1
+ writeInt8# addr# off# tag#
+ go inarr b y (off + 1)
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
+ else return off1
+ -- Representation is the same, but add operation is different
+ STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
+ STArr _ t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ off1 <- goShape (arrayShape val) off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = arraySize val
+ traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
+ return (off1 + eltsize * n)
+ STScal sty -> goScal sty val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goShape :: Shape n -> Int -> IO Int
+ goShape ShNil off = return off
+ goShape (ShCons sh n) off = do
+ off1@(I# off1#) <- goShape sh off
+ let !(I64# n'#) = fromIntegral n
+ writeInt64# addr# off1# n'#
+ return (off1 + 8)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
+ goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
+ goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
+ goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
+ goScal STBool b off@(I# off#) = do
+ let !(I8# i) = fromIntegral (fromEnum b)
+ off + 1 <$ writeInt8# addr# off# i
+
+ in () <$ go False topty top_value 0
+
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
+accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
+ go inarr ty off = case ty of
+ STNil -> return (off, ())
+ STPair a b -> do
+ (off1, x) <- go inarr a off
+ (off2, y) <- go inarr b off1
+ return (off1 + off2, (x, y))
+ STEither a b -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ (off1, val) <- case tag of
+ 0 -> fmap Left <$> go inarr a (off + 1)
+ 1 -> fmap Right <$> go inarr b (off + 1)
+ _ -> error "Invalid tag in accum memory"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
+ else return (off1, val)
+ -- Representation is the same, but add operation is different
+ STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
+ STArr ndim t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# ndim off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = shapeSize sh
+ arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
+ return (off1 + eltsize * n, arr)
+ STScal sty -> goScal sty off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
+ goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
+ goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
+ goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
+ goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
+ goScal STBool off@(I# off#) = do
+ i8 <- readInt8 addr# off#
+ return (off + 1, toEnum (fromIntegral i8))
+
+ in snd <$> go False topty 0
+
+readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
+readShape _ SZ off = return (off, ShNil)
+readShape mbarr (SS ndim) off = do
+ (off1@(I# off1#), sh) <- readShape mbarr ndim off
+ n' <- readInt64 mbarr off1#
+ return (off1 + 8, ShCons sh (fromIntegral n'))
+
+-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
+-- the list.
+data InvShape n where
+ IShNil :: InvShape Z
+ IShCons :: Int -- ^ How many subarrays are there?
+ -> Int -- ^ What is the size of all subarrays together?
+ -> InvShape n -- ^ Sub array inverted shape
+ -> InvShape (S n)
+
+ishSize :: InvShape n -> Int
+ishSize IShNil = 1
+ishSize (IShCons _ sz _) = sz
+
+invertShape :: forall n. Shape n -> InvShape n
+invertShape | Refl <- lemPlusZero @n = flip go IShNil
+ where
+ go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
+ go ShNil ish = ish
+ go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
+
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
+accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
+ go inarr ty SZ () val off = () <$ performAdd inarr ty val off
+ go inarr ty (SS dep) idx val off = case (ty, idx, val) of
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ (STMaybe t, _, _) -> _ idx val
+ (STArr rank eltty, _, _)
+ | inarr -> error "Nested arrays"
+ | otherwise -> do
+ (off1, ish) <- second invertShape <$> readShape addr# rank off
+ goArr (SS dep) ish eltty idx val off1
+ (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
+ (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
+ (STAccum{}, _, _) -> error "Nested accumulators unsupported"
+
+ goArr :: SNat i' -> InvShape n -> STy t'
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
+ goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
+ goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
+ goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
+ let i' = fromIntegral @(Rep' s TIx) @Int i
+ when (i' < 0 || i' >= n) $
+ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
+ goArr depm1 ish t1 idx val (off + i' * ishSize ish)
+
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
+ performAddArr arraySz eltty val off = do
+ let eltsize = tSize' (Proxy @s) eltty Nothing
+ forM_ [0 .. arraySz - 1] $ \lini ->
+ performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
+ return (off + arraySz * eltsize)
+
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ performAdd inarr ty val off = case ty of
+ STNil -> return off
+ STPair t1 t2 -> do
+ off1 <- performAdd inarr t1 (fst val) off
+ performAdd inarr t2 (snd val) off1
+ STEither t1 t2 -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ off1 <- case (val, tag) of
+ (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
+ (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
+ _ -> error "accumAdd: Tag mismatch for Either"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
+ else return off1
+ STArr n ty'
+ | inarr -> error "Nested array"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# n off
+ performAddArr (shapeSize sh) ty' val off1
+ STScal ty' -> performAddScal ty' val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ performAddScal STI32 (I32# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 4
+ = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
+ | otherwise
+ = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
+ performAddScal STI64 (I64# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 8
+ = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
+ | otherwise
+ = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
+ performAddScal STF32 x off@(I# off#) =
+ off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
+ performAddScal STF64 x off@(I# off#) =
+ off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
+ performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
+
+ casLoop :: Eq w
+ => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
+ -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
+ -> Addr# -- ^ Address to attempt to modify
+ -> (w -> w) -- ^ Operation to apply to the value
+ -> IO ()
+ casLoop readOp casOp addr modify = readOp addr 0# >>= loop
+ where
+ loop value = do
+ value' <- casOp addr value (modify value)
+ if value == value'
+ then return ()
+ else loop value'
+
+ in () <$ go False topty top_depth top_index top_value 0
+
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
+withAccum ty start fun = do
+ -- The initial write must happen before any of the adds or reads, so it makes
+ -- sense to put it in IO together with the allocation, instead of in AcM.
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
+ ptr <- newForeignPtr finalizerFree buffer
+ let accum = Accum ty ptr
+ accumWrite accum start
+ return accum
+ b <- fun accum
+ out <- accumRead accum
+ return (b, out)
+
+inParallel :: [AcM s t] -> AcM s [t]
+inParallel actions = AcM $ do
+ mvars <- mapM (\_ -> newEmptyMVar) actions
+ forM_ (zip actions mvars) $ \(AcM action, var) ->
+ forkIO $ action >>= putMVar var
+ mapM takeMVar mvars
+
+-- | Offset is in bytes.
+readInt8 :: Addr# -> Int# -> IO Int8
+readInt32 :: Addr# -> Int# -> IO Int32
+readInt64 :: Addr# -> Int# -> IO Int64
+readWord32 :: Addr# -> Int# -> IO Word32
+readWord64 :: Addr# -> Int# -> IO Word64
+readFloat :: Addr# -> Int# -> IO Float
+readDouble :: Addr# -> Int# -> IO Double
+readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #)
+readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #)
+readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #)
+readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #)
+readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #)
+readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #)
+readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #)
+
+writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
+writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
+writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
+writeFloat# :: Addr# -> Int# -> Float# -> IO ()
+writeDouble# :: Addr# -> Int# -> Double# -> IO ()
+writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+
+fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
+fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #)
+
+atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
+atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
+atomicCasWord32Addr addr (W32# expected) (W32# desired) =
+ IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #)
+atomicCasWord64Addr addr (W64# expected) (W64# desired) =
+ IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #)
diff --git a/src/CHAD/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs
new file mode 100644
index 0000000..fadc6be
--- /dev/null
+++ b/src/CHAD/Interpreter/Rep.hs
@@ -0,0 +1,105 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.Interpreter.Rep where
+
+import Control.DeepSeq
+import Data.Coerce (coerce)
+import Data.List (intersperse, intercalate)
+import Data.Foldable (toList)
+import Data.IORef
+import GHC.Exts (withDict)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Pretty
+import CHAD.Data
+
+
+type family Rep t where
+ Rep TNil = ()
+ Rep (TPair a b) = (Rep a, Rep b)
+ Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
+ Rep (TMaybe t) = Maybe (Rep t)
+ Rep (TArr n t) = Array n (Rep t)
+ Rep (TScal sty) = ScalRep sty
+ Rep (TAccum t) = RepAc t
+
+-- Mutable, represents monoid types t.
+type family RepAc t where
+ RepAc TNil = ()
+ RepAc (TPair a b) = (RepAc a, RepAc b)
+ RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
+ RepAc (TMaybe t) = IORef (Maybe (RepAc t))
+ RepAc (TArr n t) = Array n (RepAc t)
+ RepAc (TScal sty) = IORef (ScalRep sty)
+
+newtype Value t = Value { unValue :: Rep t }
+
+liftV :: (Rep a -> Rep b) -> Value a -> Value b
+liftV f (Value x) = Value (f x)
+
+liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
+liftV2 f (Value x) (Value y) = Value (f x y)
+
+vPair :: Value a -> Value b -> Value (TPair a b)
+vPair = liftV2 (,)
+
+vUnpair :: Value (TPair a b) -> (Value a, Value b)
+vUnpair (Value (x, y)) = (Value x, Value y)
+
+showValue :: Int -> STy t -> Rep t -> ShowS
+showValue _ STNil () = showString "()"
+showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
+showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x
+showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y
+showValue _ (STLEither _ _) Nothing = showString "LNil"
+showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
+showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
+showValue _ (STMaybe _) Nothing = showString "Nothing"
+showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
+showValue d (STArr _ t) arr = showParen (d > 10) $
+ showString "arrayFromList " . showsPrec 11 (arrayShape arr)
+ . showString " ["
+ . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
+ . showString "]"
+showValue d (STScal sty) x = case sty of
+ STF32 -> showsPrec d x
+ STF64 -> showsPrec d x
+ STI32 -> showsPrec d x
+ STI64 -> showsPrec d x
+ STBool -> showsPrec d x
+showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
+
+showEnv :: SList STy env -> SList Value env -> String
+showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
+ where
+ showEntries :: SList STy env -> SList Value env -> [String]
+ showEntries SNil SNil = []
+ showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+
+rnfRep :: STy t -> Rep t -> ()
+rnfRep STNil () = ()
+rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
+rnfRep (STEither a _) (Left x) = rnfRep a x
+rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
+rnfRep (STMaybe _) Nothing = ()
+rnfRep (STMaybe t) (Just x) = rnfRep t x
+rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+rnfRep (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+rnfRep STAccum{} _ = error "Cannot rnf accumulators"
+
+instance KnownTy t => NFData (Value t) where
+ rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs
new file mode 100644
index 0000000..6dc91a5
--- /dev/null
+++ b/src/CHAD/Language.hs
@@ -0,0 +1,266 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitForAll #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+module CHAD.Language (
+ fromNamed,
+ NExpr,
+ Ex,
+ module CHAD.Language,
+ module CHAD.AST.Types,
+ Lookup,
+) where
+
+import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.AST.Types
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.Language.AST
+
+
+data a :-> b = a :-> b
+ deriving (Show)
+infixr 0 :->
+
+
+body :: NExpr env t -> NFun env env t
+body = NBody
+
+lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+lambda = NLam
+
+inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
+inline = inlineNFun
+
+-- To be used to construct the argument list for 'inline'.
+--
+-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
+-- > in inline fun (SNil .$ 16 .$ 26)
+(.$) :: SList f list -> f a -> SList f (a : list)
+(.$) = flip SCons
+
+
+let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
+let_ = NELet
+
+pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
+pair = NEPair
+
+fst_ :: NExpr env (TPair a b) -> NExpr env a
+fst_ = NEFst
+
+snd_ :: NExpr env (TPair a b) -> NExpr env b
+snd_ = NESnd
+
+nil :: NExpr env TNil
+nil = NENil
+
+inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
+inl = NEInl knownTy
+
+inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
+inr = NEInr knownTy
+
+case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
+case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
+
+nothing :: KnownTy a => NExpr env (TMaybe a)
+nothing = NENothing knownTy
+
+just :: NExpr env a -> NExpr env (TMaybe a)
+just = NEJust
+
+maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b
+maybe_ a (v :-> b) c = NEMaybe a v b c
+
+constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+constArr_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConstArr knownNat ty x
+
+build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
+build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b))
+
+build2 :: NExpr env TIx -> NExpr env TIx
+ -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
+ -> NExpr env (TArr (S (S Z)) t)
+build2 a1 a2 (v1 :-> v2 :-> b) =
+ NEBuild (SS (SS SZ))
+ (pair (pair nil a1) a2)
+ #idx
+ (let_ v1 (snd_ (fst_ #idx)) $
+ let_ v2 (NEDrop SZ (snd_ #idx)) $
+ NEDrop (SS (SS SZ)) b)
+
+build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
+build n a (v :-> b) = NEBuild n a v b
+
+map_ :: forall n a b env name. (KnownNat n, KnownTy a)
+ => (Var name a :-> NExpr ('(name, a) : env) b)
+ -> NExpr env (TArr n a) -> NExpr env (TArr n b)
+map_ (v :-> a) b = NEMap v a b
+
+fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t t)
+ in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3
+
+sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+sum1i e = NESum1Inner e
+
+unit :: NExpr env t -> NExpr env (TArr Z t)
+unit = NEUnit
+
+replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
+replicate1i n a = NEReplicate1Inner n a
+
+maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+maximum1i e = NEMaximum1Inner e
+
+minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+minimum1i e = NEMinimum1Inner e
+
+reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+reshape = NEReshape
+
+fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t1 t1)
+ in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3
+
+fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
+ -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
+
+const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
+const_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty x
+
+idx0 :: NExpr env (TArr Z t) -> NExpr env t
+idx0 = NEIdx0
+
+-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
+-- (.!) = NEIdx1
+-- infixl 9 .!
+
+(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+(!) = NEIdx
+infixl 9 !
+
+shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
+shape = NEShape
+
+length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
+length_ e = snd_ (shape e)
+
+oper :: SOp a t -> NExpr env a -> NExpr env t
+oper = NEOp
+
+oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
+oper2 op a b = NEOp op (pair a b)
+
+error_ :: KnownTy t => String -> NExpr env t
+error_ s = NEError knownTy s
+
+custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
+ -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
+ -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
+ NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
+with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
+with a (n :-> b) = NEWith (knownMTy @t) a n b
+
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
+
+
+(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .== b = oper (OEq knownScalTy) (pair a b)
+infix 4 .==
+
+(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .< b = oper (OLt knownScalTy) (pair a b)
+infix 4 .<
+
+(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>) = flip (.<)
+infix 4 .>
+
+(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .<= b = oper (OLe knownScalTy) (pair a b)
+infix 4 .<=
+
+(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>=) = flip (.<=)
+infix 4 .>=
+
+not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+not_ = oper ONot
+
+and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+and_ = oper2 OAnd
+infixr 3 `and_`
+
+or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+or_ = oper2 OOr
+infixr 2 `or_`
+
+mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
+mod_ = oper2 (OMod knownScalTy)
+infixl 7 `mod_`
+
+-- | The first alternative is the True case; the second is the False case.
+if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
+if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)
+
+round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
+round_ = oper ORound64
+
+toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
+toFloat_ = oper OToFl64
+
+idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
+idiv = oper2 (OIDiv knownScalTy)
+infixl 7 `idiv`
diff --git a/src/CHAD/Language/AST.hs b/src/CHAD/Language/AST.hs
new file mode 100644
index 0000000..b270844
--- /dev/null
+++ b/src/CHAD/Language/AST.hs
@@ -0,0 +1,300 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.Language.AST where
+
+import Data.Kind (Type)
+import Data.Type.Equality
+import GHC.OverloadedLabels
+import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+type NExpr :: [(Symbol, Ty)] -> Ty -> Type
+data NExpr env t where
+ -- lambda calculus
+ NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
+ NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
+
+ -- environment management
+ NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t
+
+ -- base types
+ NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
+ NEFst :: NExpr env (TPair a b) -> NExpr env a
+ NESnd :: NExpr env (TPair a b) -> NExpr env b
+ NENil :: NExpr env TNil
+ NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
+ NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
+ NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
+ NENothing :: STy t -> NExpr env (TMaybe t)
+ NEJust :: NExpr env t -> NExpr env (TMaybe t)
+ NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b
+
+ -- array operations
+ NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+ NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
+ NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEUnit :: NExpr env t -> NExpr env (TArr Z t)
+ NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
+ NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+
+ NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+
+ -- expression operations
+ NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
+ NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
+ NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
+ NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+ NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
+ NEOp :: SOp a t -> NExpr env a -> NExpr env t
+
+ -- custom derivatives
+ NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation
+ -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
+ -- accumulation effect on monoids
+ NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+
+ -- partiality
+ NEError :: STy a -> String -> NExpr env a
+
+ -- embedded unnamed expressions
+ NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
+deriving instance Show (NExpr env t)
+
+type Lookup name env = Lookup1 (name == "_") name env
+type family Lookup1 eqblank name env where
+ Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup1 False name env = Lookup2 name env
+type family Lookup2 name env where
+ Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
+ Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
+type family Lookup3 eq t name env where
+ Lookup3 True t _ _ = t
+ Lookup3 False _ name env = Lookup2 name env
+
+type family DropNth i env where
+ DropNth Z (_ : env) = env
+ DropNth (S i) (p : env) = p : DropNth i env
+
+data Var name t = Var (SSymbol name) (STy t)
+ deriving (Show)
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
+ a + b = NEOp (OAdd knownScalTy) (NEPair a b)
+ a * b = NEOp (OMul knownScalTy) (NEPair a b)
+ negate e = NEOp (ONeg knownScalTy) e
+ abs = error "abs undefined for NExpr"
+ signum = error "signum undefined for NExpr"
+ fromInteger =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty . fromInteger
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
+ => Fractional (NExpr env t) where
+ recip e = NEOp (ORecip knownScalTy) e
+ fromRational =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty . fromRational
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
+ => Floating (NExpr env t) where
+ pi =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty pi
+ exp = NEOp (OExp knownScalTy)
+ log = NEOp (OExp knownScalTy)
+ sin = undefined ; cos = undefined ; tan = undefined
+ asin = undefined ; acos = undefined ; atan = undefined
+ sinh = undefined ; cosh = undefined
+ asinh = undefined ; acosh = undefined ; atanh = undefined
+
+instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
+ fromLabel = Var symbolSing knownTy
+
+instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
+ fromLabel = NEVar (fromLabel @name)
+
+-- | Innermost variable variable on the outside, on the right.
+data NEnv env where
+ NTop :: NEnv '[]
+ NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
+
+-- | First (outermost) parameter on the outside, on the left.
+-- * env: environment of this function (grows as you go deeper inside lambdas)
+-- * env': environment of the body of the function
+-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
+data NFun env env' t where
+ NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+ NBody :: NExpr env' t -> NFun env' env' t
+
+type family UnName env where
+ UnName '[] = '[]
+ UnName ('(name, t) : env) = t : UnName env
+
+envFromNEnv :: NEnv env -> SList STy (UnName env)
+envFromNEnv NTop = SNil
+envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env
+
+inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
+inlineNFun fun args = NEUnnamed (fromNamed fun) args
+
+fromNamed :: NFun '[] env t -> Ex (UnName env) t
+fromNamed = fromNamedFun NTop
+
+-- | Some of the parameters have already been put in the environment; some
+-- haven't. Transfer all parameters to the left into the environment.
+--
+-- [] `fromNamedFun` λx y z. E
+-- = []:x `fromNamedFun` λy z. E
+-- = []:x:y `fromNamedFun` λz. E
+-- = []:x:y:z `fromNamedFun` λ. E
+-- = []:x:y:z `fromNamedExpr` E
+fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
+fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
+fromNamedFun env (NBody e) = fromNamedExpr env e
+
+fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
+fromNamedExpr val = \case
+ NEVar var@(Var _ ty)
+ | Just idx <- find var val -> EVar ext ty idx
+ | otherwise -> error "Variable out of scope in conversion from surface \
+ \expression to De Bruijn expression"
+ NELet n a b -> ELet ext (go a) (lambda val n b)
+
+ NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
+
+ NEPair a b -> EPair ext (go a) (go b)
+ NEFst e -> EFst ext (go e)
+ NESnd e -> ESnd ext (go e)
+ NENil -> ENil ext
+ NEInl t e -> EInl ext t (go e)
+ NEInr t e -> EInr ext t (go e)
+ NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
+ NENothing t -> ENothing ext t
+ NEJust e -> EJust ext (go e)
+ NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c)
+
+ NEConstArr n t x -> EConstArr ext n t x
+ NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
+ NEMap n a b -> EMap ext (lambda val n a) (go b)
+ NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
+ NESum1Inner e -> ESum1Inner ext (go e)
+ NEUnit e -> EUnit ext (go e)
+ NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
+ NEMaximum1Inner e -> EMaximum1Inner ext (go e)
+ NEMinimum1Inner e -> EMinimum1Inner ext (go e)
+ NEReshape n a b -> EReshape ext n (go a) (go b)
+
+ NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+
+ NEConst t x -> EConst ext t x
+ NEIdx0 e -> EIdx0 ext (go e)
+ NEIdx1 a b -> EIdx1 ext (go a) (go b)
+ NEIdx a b -> EIdx ext (go a) (go b)
+ NEShape e -> EShape ext (go e)
+ NEOp op e -> EOp ext op (go e)
+
+ NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
+ ECustom ext ta tb ttape
+ (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
+ (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
+ (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
+ (go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
+
+ NEWith t a n b -> EWith ext t (go a) (lambda val n b)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
+
+ NEError t s -> EError ext t s
+
+ NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
+ where
+ go :: NExpr env t' -> Ex (UnName env) t'
+ go = fromNamedExpr val
+
+ find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
+ find _ NTop = Nothing
+ find var@(Var s ty) (val' `NPush` Var s' ty')
+ | Just Refl <- testEquality s s'
+ , Just Refl <- testEquality ty ty'
+ = Just IZ
+ | otherwise
+ = IS <$> find var val'
+
+ lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
+ lambda val' var e = fromNamedExpr (val' `NPush` var) e
+
+ lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
+ lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e
+
+ injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
+ injectWrapLet e SNil = e
+ injectWrapLet e (arg `SCons` args) =
+ injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
+ args
+
+dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
+dropNth SZ (val `NPush` _) = val
+dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
+dropNth _ NTop = error "DropNth: index out of range"
+
+dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
+dropNthW SZ (_ `NPush` _) = WSink
+dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
+dropNthW _ NTop = error "DropNth: index out of range"
+
+assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
+assertSymbolNotUnderscore s@SSymbol k =
+ case symbolVal s of
+ "_" -> error "assertSymbolNotUnderscore: was underscore"
+ _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
+
+assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
+assertSymbolDistinct s1@SSymbol s2@SSymbol k
+ | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
+ | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
+
+equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
+equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k
diff --git a/src/CHAD/Lemmas.hs b/src/CHAD/Lemmas.hs
new file mode 100644
index 0000000..55ef042
--- /dev/null
+++ b/src/CHAD/Lemmas.hs
@@ -0,0 +1,21 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+{-# LANGUAGE AllowAmbiguousTypes #-}
+module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where
+
+import Data.Type.Equality
+import Unsafe.Coerce (unsafeCoerce)
+
+
+type family Append a b where
+ Append '[] l = l
+ Append (x : xs) l = x : Append xs l
+
+lemAppendNil :: Append a '[] :~: a
+lemAppendNil = unsafeCoerce Refl
+
+lemAppendAssoc :: Append a (Append b c) :~: Append (Append a b) c
+lemAppendAssoc = unsafeCoerce Refl
diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs
new file mode 100644
index 0000000..2510cc5
--- /dev/null
+++ b/src/CHAD/Simplify.hs
@@ -0,0 +1,619 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Simplify (
+ simplifyN, simplifyFix,
+ SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
+) where
+
+import Control.Monad (ap)
+import Data.Bifunctor (first)
+import Data.Function (fix)
+import Data.Monoid (Any(..))
+
+import Debug.Trace
+
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Pretty
+import CHAD.AST.Sparse.Types
+import CHAD.AST.UnMonoid (acPrjCompose)
+import CHAD.Data
+import CHAD.Simplify.TH
+
+
+data SimplifyConfig = SimplifyConfig
+ { scLogging :: Bool
+ }
+
+defaultSimplifyConfig :: SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig False
+
+simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
+simplifyN 0 = id
+simplifyN n = simplifyN (n - 1) . simplify
+
+simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
+simplify =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = defaultSimplifyConfig
+ in snd . runSM . simplify'
+
+simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
+simplifyWith config =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = config
+ in snd . runSM . simplify'
+
+simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
+simplifyFix = simplifyFixWith defaultSimplifyConfig
+
+simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
+simplifyFixWith config =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = config
+ in fix $ \loop e ->
+ let (act, e') = runSM (simplify' e)
+ in if act then loop e' else e'
+
+-- | simplify monad
+newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
+ deriving (Functor)
+
+instance Applicative (SM tenv tt env t) where
+ pure x = SM (\_ -> (Any False, x))
+ (<*>) = ap
+
+instance Monad (SM tenv tt env t) where
+ SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
+
+runSM :: SM env t env t a -> (Bool, a)
+runSM (SM f) = first getAny (f id)
+
+smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
+smReconstruct core = SM (\ctx -> (Any False, ctx core))
+
+class Monad m => ActedMonad m where
+ tellActed :: m ()
+ hideActed :: m a -> m a
+ liftActed :: (Any, a) -> m a
+
+instance ActedMonad ((,) Any) where
+ tellActed = (Any True, ())
+ hideActed (_, x) = (Any False, x)
+ liftActed = id
+
+instance ActedMonad (SM tenv tt env t) where
+ tellActed = SM (\_ -> tellActed)
+ hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
+ liftActed pair = SM (\_ -> pair)
+
+-- more convenient in practice
+acted :: ActedMonad m => m a -> m a
+acted m = tellActed >> m
+
+within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
+within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify' expr
+ | scLogging ?config = do
+ res <- simplify'Rec expr
+ full <- smReconstruct res
+ let printed = ppExpr knownEnv full
+ replace a bs = concatMap (\x -> if x == a then bs else [x])
+ str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
+ | otherwise = "--- simplify step: " ++ printed
+ traceM str
+ return res
+ | otherwise = simplify'Rec expr
+
+simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify'Rec = \case
+ -- inlining
+ ELet _ rhs body
+ | cheapExpr rhs
+ -> acted $ simplify' (substInline rhs body)
+
+ | Occ lexOcc runOcc <- occCount IZ body
+ , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
+ || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
+ -> acted $ simplify' (substInline rhs body)
+
+ -- let splitting / let peeling
+ ELet _ (EPair _ a b) body ->
+ acted $ simplify' $
+ ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
+ IS i -> EVar ext t (IS (IS i)))
+ body
+ ELet _ (EJust _ a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInl _ t2 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInr _ t1 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
+
+ -- let rotation
+ ELet _ (ELet _ rhs a) b -> do
+ b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
+ acted $ simplify' $
+ ELet ext rhs $
+ ELet ext a $
+ weakenExpr (WCopy WSink) b'
+
+ -- beta rules for products
+ EFst _ (EPair _ e e')
+ | not (hasAdds e') -> acted $ simplify' e
+ | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
+ ESnd _ (EPair _ e' e)
+ | not (hasAdds e') -> acted $ simplify' e
+ | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
+
+ -- beta rules for coproducts
+ ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
+ ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)
+
+ -- beta rules for maybe
+ EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
+ EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
+
+ -- let floating
+ EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
+ ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
+ ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
+ EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
+ EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
+ EAccum _ t p e1 sp (ELet _ rhs body) acc ->
+ acted $ simplify' $
+ ELet ext rhs $
+ EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
+
+ -- let () = e in () ~> e
+ ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
+ acted $ simplify' e1
+
+ -- map (\_ -> x) e ~> build (shape e) (\_ -> x)
+ EMap _ e1 e2
+ | Occ Zero Zero <- occCount IZ e1
+ , STArr n _ <- typeOf e2 ->
+ acted $ simplify' $
+ EBuild ext n (EShape ext e2) $
+ subst (\_ t' -> \case IZ -> error "Unused variable was used"
+ IS i -> EVar ext t' (IS i))
+ e1
+
+ -- vertical fusion
+ EMap _ e1 (EMap _ e2 e3) ->
+ acted $ simplify' $
+ EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3
+
+ -- projection down-commuting
+ EFst _ (ECase _ e1 e2 e3) ->
+ acted $ simplify' $
+ ECase ext e1 (EFst ext e2) (EFst ext e3)
+ ESnd _ (ECase _ e1 e2 e3) ->
+ acted $ simplify' $
+ ECase ext e1 (ESnd ext e2) (ESnd ext e3)
+ EFst _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (EFst ext e1) (EFst ext e2) e3
+ ESnd _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
+
+ -- TODO: more array indexing
+ EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2
+ EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1
+ EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1
+
+ -- TODO: more array shape
+ EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1
+ EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2)
+
+ -- TODO: more constant folding
+ EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
+
+ -- inline cheap array constructors
+ ELet _ (EReplicate1Inner _ e1 e2) e3 ->
+ acted $ simplify' $
+ ELet ext (EPair ext e1 e2) $
+ let v = EVar ext (STPair tIx (typeOf e2)) IZ
+ in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
+ -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
+ -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
+ -- -- Should do proper SoA representation.
+ -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
+ -- acted $ simplify' $
+ -- ELet ext e1 $
+ -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
+
+ -- eta rule for unit
+ e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
+ case e of
+ ENil _ -> return e
+ _ -> acted $ return (ENil ext)
+
+ EBuild _ SZ _ e ->
+ acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
+
+ -- monoid rules
+ EAccum _ t p e1 sp e2 acc -> do
+ e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
+ simplifyOHT (OneHotTerm SAID t p e1' sp e2')
+ (acted $ return (ENil ext))
+ (\sp' (InContext w wrap e) -> do
+ e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
+ return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
+ (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
+ -- The acted management here is a hideous mess.
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
+ return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
+ EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
+ EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
+ EOneHot _ t p e1 e2 -> do
+ e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
+ e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
+ simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
+ (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
+ (\sp' (InContext _ wrap e) ->
+ case isDense t sp' of
+ Just Refl -> do
+ e' <- hideActed $ within wrap $ simplify' e
+ return (wrap e')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+ (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
+ case isDense (acPrjTy p' t') sp' of
+ Just Refl -> do
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
+ return (wrap $ EOneHot ext t' p' e1''' e2''')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+
+ -- type-specific equations for plus
+ EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
+ acted $ return (ENil ext)
+
+ EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
+ acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
+
+ EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
+ acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
+ EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
+ acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
+ EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
+ EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
+
+ EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
+ acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
+ EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
+ EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
+
+ -- fallback recursion
+ EVar _ t i -> pure $ EVar ext t i
+ ELet _ a b -> [simprec| ELet ext *a *b |]
+ EPair _ a b -> [simprec| EPair ext *a *b |]
+ EFst _ e -> [simprec| EFst ext *e |]
+ ESnd _ e -> [simprec| ESnd ext *e |]
+ ENil _ -> pure $ ENil ext
+ EInl _ t e -> [simprec| EInl ext t *e |]
+ EInr _ t e -> [simprec| EInr ext t *e |]
+ ECase _ e a b -> [simprec| ECase ext *e *a *b |]
+ ENothing _ t -> pure $ ENothing ext t
+ EJust _ e -> [simprec| EJust ext *e |]
+ EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
+ ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
+ ELInl _ t e -> [simprec| ELInl ext t *e |]
+ ELInr _ t e -> [simprec| ELInr ext t *e |]
+ ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
+ EConstArr _ n t v -> pure $ EConstArr ext n t v
+ EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EMap _ a b -> [simprec| EMap ext *a *b |]
+ EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
+ ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
+ EUnit _ e -> [simprec| EUnit ext *e |]
+ EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
+ EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
+ EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
+ EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
+ EZip _ a b -> [simprec| EZip ext *a *b |]
+ EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
+ EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]
+ EConst _ t v -> pure $ EConst ext t v
+ EIdx0 _ e -> [simprec| EIdx0 ext *e |]
+ EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
+ EIdx _ a b -> [simprec| EIdx ext *a *b |]
+ EShape _ e -> [simprec| EShape ext *e |]
+ EOp _ op e -> [simprec| EOp ext op *e |]
+ ECustom _ s t p a b c e1 e2 -> do
+ a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
+ b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
+ c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
+ e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
+ pure (ECustom ext s t p a' b' c' e1' e2')
+ ERecompute _ e -> [simprec| ERecompute ext *e |]
+ EWith _ t e1 e2 -> do
+ e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
+ pure (EWith ext t e1' e2')
+ -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |]
+ -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |]
+ EZero _ t e -> [simprec| EZero ext t *e |]
+ EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
+ EError _ t s -> pure $ EError ext t s
+
+-- | This can be made more precise by tracking (and not counting) adds on
+-- locally eliminated accumulators.
+hasAdds :: Expr x env t -> Bool
+hasAdds = \case
+ EVar _ _ _ -> False
+ ELet _ rhs body -> hasAdds rhs || hasAdds body
+ EPair _ a b -> hasAdds a || hasAdds b
+ EFst _ e -> hasAdds e
+ ESnd _ e -> hasAdds e
+ ENil _ -> False
+ EInl _ _ e -> hasAdds e
+ EInr _ _ e -> hasAdds e
+ ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
+ ENothing _ _ -> False
+ EJust _ e -> hasAdds e
+ EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
+ ELNil _ _ _ -> False
+ ELInl _ _ e -> hasAdds e
+ ELInr _ _ e -> hasAdds e
+ ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
+ EConstArr _ _ _ _ -> False
+ EBuild _ _ a b -> hasAdds a || hasAdds b
+ EMap _ a b -> hasAdds a || hasAdds b
+ EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ ESum1Inner _ e -> hasAdds e
+ EUnit _ e -> hasAdds e
+ EReplicate1Inner _ a b -> hasAdds a || hasAdds b
+ EMaximum1Inner _ e -> hasAdds e
+ EMinimum1Inner _ e -> hasAdds e
+ EReshape _ _ a b -> hasAdds a || hasAdds b
+ EZip _ a b -> hasAdds a || hasAdds b
+ EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
+ EConst _ _ _ -> False
+ EIdx0 _ e -> hasAdds e
+ EIdx1 _ a b -> hasAdds a || hasAdds b
+ EIdx _ a b -> hasAdds a || hasAdds b
+ EShape _ e -> hasAdds e
+ EOp _ _ e -> hasAdds e
+ EWith _ _ a b -> hasAdds a || hasAdds b
+ ERecompute _ e -> hasAdds e
+ EAccum _ _ _ _ _ _ _ -> True
+ EZero _ _ e -> hasAdds e
+ EDeepZero _ _ e -> hasAdds e
+ EPlus _ _ a b -> hasAdds a || hasAdds b
+ EOneHot _ _ _ a b -> hasAdds a || hasAdds b
+ EError _ _ _ -> False
+
+checkAccumInScope :: SList STy env -> Bool
+checkAccumInScope = \case SNil -> False
+ SCons t env -> check t || checkAccumInScope env
+ where
+ check :: STy t -> Bool
+ check STNil = False
+ check (STPair s t) = check s || check t
+ check (STEither s t) = check s || check t
+ check (STLEither s t) = check s || check t
+ check (STMaybe t) = check t
+ check (STArr _ t) = check t
+ check (STScal _) = False
+ check STAccum{} = True
+
+data OneHotTerm dense env a where
+ OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
+deriving instance Show (OneHotTerm dense env a)
+
+data InContext f env (a :: Ty) where
+ InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
+
+simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
+ val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
+ return $ OneHotTerm dense t prj idx sp val'
+
+simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
+simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
+ unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
+ acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
+ return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
+simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
+
+simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
+ | Just Refl <- isDense (acPrjTy prj1 t1) sp =
+ let idx2' :: Ex env (AcIdx dense p2 c)
+ idx2' = case dense of
+ SAID -> reduceAcIdx t2 prj2 idx2
+ SAIS -> idx2
+ in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
+ acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
+simplifyOHT_concat oht = return oht
+
+-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
+-- -- dense, then the Sparse in the output will also be dense. This property is
+-- -- used when simplifying EOneHot, which cannot represent sparsity.
+simplifyOHT :: ActedMonad m => OneHotTerm dense env a
+ -> m r -- ^ Zero case (onehot is actually zero)
+ -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
+ -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
+ -> m r
+simplifyOHT oht kzero ktriv k = do
+ -- traceM $ "sOHT: input " ++ show oht
+ oht1 <- simplifyOHT_recogniseMonoid oht
+ -- traceM $ "sOHT: recog " ++ show oht1
+ InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
+ -- traceM $ "sOHT: unspa " ++ show oht2
+ oht3 <- simplifyOHT_concat oht2
+ -- traceM $ "sOHT: conca " ++ show oht3
+ -- traceM ""
+ case oht3 of
+ OneHotTerm _ _ _ _ _ EZero{} -> kzero
+ OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
+ _ -> k (InContext w1 wrap1 oht3)
+
+-- Sets the acted flag whenever a non-trivial projection is returned or the
+-- output Sparse is different from the input Sparse.
+unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
+ -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
+ -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
+unsparseOneHotD topsp topval k = case (topsp, topval) of
+ -- eliminate always-Just sparse onehot
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
+
+ -- expand the top levels of a onehot for a sparse type into a onehot for the
+ -- corresponding non-sparse type
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPFst spprj) idx' s1' e'
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPSnd spprj) idx' s1' e'
+ (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPLeft spprj) idx' s1' e'
+ (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPRight spprj) idx' s1' e'
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPJust spprj) idx' s1' e'
+ (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
+ | Dict <- styKnown (typeOf idx) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
+ acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
+
+ -- anything else we don't know how to improve
+ _ -> k WId id SAPHere (ENil ext) topsp topval
+
+{-
+unsparseOneHotS :: ActedMonad m
+ => Sparse a a' -> Ex env a'
+ -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
+unsparseOneHotS topsp topval k = case (topsp, topval) of
+ -- order is relevant to make sure we set the acted flag correctly
+ (SpAbsent, v@ENil{}) -> k SpAbsent v
+ (SpAbsent, v@EZero{}) -> k SpAbsent v
+ (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+
+ -- the unsparsifying
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
+
+ -- recursion
+ -- TODO: coproducts could safely become projections as they do not need
+ -- zeroinfo. But that would only work if the coproduct is at the top, because
+ -- as soon as we hit a product, we need zeroinfo to make it a projection and
+ -- we don't have that.
+ (SpSparse s, e) -> k (SpSparse s) e
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
+ acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
+ acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
+ case s2 of SpAbsent -> pure () ; _ -> tellActed
+ k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
+ case s1 of SpAbsent -> pure () ; _ -> tellActed
+ acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
+ k (SpMaybe s1') (EJust ext e')
+ (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
+ k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
+ _ -> _
+-}
+
+-- | Recognises 'EZero' and 'EOneHot'.
+recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
+recogniseMonoid _ e@EOneHot{} = return e
+recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
+ ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (a', b') -> return $ EPair ext a' b'
+recogniseMonoid typ@(SMTLEither t1 t2) expr =
+ case expr of
+ ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ _ -> return expr
+recogniseMonoid typ@(SMTMaybe t1) expr =
+ case expr of
+ ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ _ -> return expr
+recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
+ acted $ do
+ e' <- recogniseMonoid t e
+ return $
+ ELet ext e' $
+ EOneHot ext typ (SAPArrIdx SAPHere)
+ (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ))))
+ (ENil ext))
+ (EVar ext (fromSMTy t) IZ)
+recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
+ (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ _ -> return e
+recogniseMonoid _ e = return e
+
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
+ (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
+ (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
+ where
+ -- invariant: AcIdx expression is duplicable
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go t SAPHere _ e = makeZeroInfo t e
+ go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
+ go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
+ go SMTLEither{} _ _ _ = ENil ext
+ go SMTMaybe{} _ _ _ = ENil ext
+ go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
diff --git a/src/CHAD/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs
new file mode 100644
index 0000000..4af5394
--- /dev/null
+++ b/src/CHAD/Simplify/TH.hs
@@ -0,0 +1,80 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module CHAD.Simplify.TH (simprec) where
+
+import Data.Bifunctor (first)
+import Data.Char
+import Data.List (foldl', foldl1')
+import Language.Haskell.TH
+import Language.Haskell.TH.Quote
+import Text.ParserCombinators.ReadP
+
+
+-- [simprec| EPair ext *a *b |]
+-- ~>
+-- do a' <- within (\a' -> EPair ext a' b) (simplify' a)
+-- b' <- within (\b' -> EPair ext a' b') (simplify' b)
+-- pure (EPair ext a' b')
+
+simprec :: QuasiQuoter
+simprec = QuasiQuoter
+ { quoteDec = \_ -> fail "simprec used outside of expression context"
+ , quoteType = \_ -> fail "simprec used outside of expression context"
+ , quoteExp = handler
+ , quotePat = \_ -> fail "simprec used outside of expression context"
+ }
+
+handler :: String -> Q Exp
+handler str =
+ case readP_to_S pTemplate str of
+ [(template, "")] -> generate template
+ _:_:_ -> fail "simprec: template grammar ambiguous"
+ _ -> fail "simprec: could not parse template"
+
+generate :: Template -> Q Exp
+generate (Template topitems) =
+ let takePrefix (Plain x : xs) = first (x:) (takePrefix xs)
+ takePrefix xs = ([], xs)
+
+ itemVar "" = error "simprec: empty item name?"
+ itemVar name@(c:_) | isLower c = VarE (mkName name)
+ | isUpper c = ConE (mkName name)
+ | otherwise = error "simprec: non-letter item name?"
+
+ loop :: Exp -> [Item] -> Q [Stmt]
+ loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)]
+ loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs
+ loop yet (Recurse x : xs) = do
+ primeName <- newName (x ++ "'")
+ let appPrePrime e (Plain y) = e `AppE` itemVar y
+ appPrePrime e (Recurse y) = e `AppE` itemVar y
+ let stmt = BindS (VarP primeName) $
+ VarE (mkName "within")
+ `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs)
+ `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x))
+ stmts <- loop (yet `AppE` VarE primeName) xs
+ return (stmt : stmts)
+
+ (prefix, items') = takePrefix topitems
+ in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items'
+
+data Template = Template [Item]
+ deriving (Show)
+
+data Item = Plain String | Recurse String
+ deriving (Show)
+
+pTemplate :: ReadP Template
+pTemplate = do
+ items <- many (skipSpaces >> pItem)
+ skipSpaces
+ eof
+ return (Template items)
+
+pItem :: ReadP Item
+pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName)
+
+pName :: ReadP String
+pName = do
+ c1 <- satisfy (\c -> isAlpha c || c == '_')
+ cs <- munch (\c -> isAlphaNum c || c `elem` "_'")
+ return (c1:cs)
diff --git a/src/CHAD/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs
new file mode 100644
index 0000000..d4fd945
--- /dev/null
+++ b/src/CHAD/Util/IdGen.hs
@@ -0,0 +1,19 @@
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+module CHAD.Util.IdGen where
+
+import Control.Monad.Fix
+import Control.Monad.Trans.State.Strict
+
+
+newtype IdGen a = IdGen (State Int a)
+ deriving newtype (Functor, Applicative, Monad, MonadFix)
+
+genId :: IdGen Int
+genId = IdGen (state (\i -> (i, i + 1)))
+
+runIdGen :: Int -> IdGen a -> a
+runIdGen start (IdGen m) = evalState m start
+
+runIdGen' :: Int -> IdGen a -> (a, Int)
+runIdGen' start (IdGen m) = runState m start