aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md5
-rw-r--r--bench/Main.hs42
-rw-r--r--chad-fast.cabal10
-rw-r--r--src/AST.hs297
-rw-r--r--src/AST/Accum.hs75
-rw-r--r--src/AST/Bindings.hs20
-rw-r--r--src/AST/Count.hs915
-rw-r--r--src/AST/Env.hs84
-rw-r--r--src/AST/Pretty.hs80
-rw-r--r--src/AST/Sparse.hs287
-rw-r--r--src/AST/Sparse/Types.hs107
-rw-r--r--src/AST/SplitLets.hs49
-rw-r--r--src/AST/Types.hs42
-rw-r--r--src/AST/UnMonoid.hs123
-rw-r--r--src/AST/Weaken.hs6
-rw-r--r--src/AST/Weaken/Auto.hs2
-rw-r--r--src/Analysis/Identity.hs48
-rw-r--r--src/Array.hs5
-rw-r--r--src/CHAD.hs1246
-rw-r--r--src/CHAD/Accum.hs57
-rw-r--r--src/CHAD/EnvDescr.hs20
-rw-r--r--src/CHAD/Top.hs57
-rw-r--r--src/CHAD/Types.hs53
-rw-r--r--src/CHAD/Types/ToTan.hs26
-rw-r--r--src/Compile.hs574
-rw-r--r--src/Compile/Exec.hs22
-rw-r--r--src/Data.hs8
-rw-r--r--src/Data/VarMap.hs4
-rw-r--r--src/Example.hs35
-rw-r--r--src/ForwardAD.hs48
-rw-r--r--src/ForwardAD/DualNumbers.hs10
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs4
-rw-r--r--src/Interpreter.hs187
-rw-r--r--src/Interpreter/Rep.hs14
-rw-r--r--src/Language.hs22
-rw-r--r--src/Language/AST.hs23
-rw-r--r--src/Simplify.hs344
-rw-r--r--src/Simplify/TH.hs2
-rw-r--r--test-framework/Test/Framework.hs383
-rw-r--r--test/Main.hs230
40 files changed, 4117 insertions, 1449 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..585e3b0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,5 @@
+# chad-fast
+
+This is work-in-progress research software. If you are interested in techniques
+or ideas in this repository, please do [contact me](https://tomsmeding.com),
+I'm always happy to chat.
diff --git a/bench/Main.hs b/bench/Main.hs
index 358ba31..ec9264b 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,11 +1,13 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
@@ -17,6 +19,7 @@ import Data.Kind (Constraint)
import GHC.Exts (withDict)
import AST
+import AST.Count
import AST.UnMonoid
import Array
import qualified CHAD (defaultConfig)
@@ -34,8 +37,12 @@ import Simplify
gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
compile knownEnv $
- simplifyFix $ unMonoid $ simplifyFix $
- ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ simplifyFix $
+ ELet ext (EConst ext STF64 1.0) $
+ chad' config knownEnv $
+ simplifyFix term
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
@@ -93,18 +100,19 @@ makeGMMInputs =
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig
-main :: IO ()
-main = defaultMain
- [env (return makeNeuralInputs) $ \inputs -> bgroup "neural"
- [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
- bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig neural) $ \fun ->
- bench "accum" (nfAppIO fun inputs)
- ]
- ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm"
- [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
+bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
+ => String -> Ex env R -> SList Value env -> Benchmark
+bgroupDefaultAccum name term inputs =
+ bgroup name
+ [env (gradCHAD CHAD.defaultConfig term) $ \fun ->
bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
+ ,env (gradCHAD accumConfig term) $ \fun ->
bench "accum" (nfAppIO fun inputs)
]
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural
+ ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False)
+ ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil)
]
diff --git a/chad-fast.cabal b/chad-fast.cabal
index b0ed639..df0409d 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -10,6 +10,9 @@ build-type: Simple
library
exposed-modules:
+ -- default ghci module on top
+ Example
+
Analysis.Identity
Array
AST
@@ -18,6 +21,8 @@ library
AST.Count
AST.Env
AST.Pretty
+ AST.Sparse
+ AST.Sparse.Types
AST.SplitLets
AST.Types
AST.UnMonoid
@@ -33,7 +38,6 @@ library
Compile.Exec
Data
Data.VarMap
- Example
Example.GMM
Example.Types
ForwardAD
@@ -81,7 +85,11 @@ library test-framework
exposed-modules: Test.Framework
build-depends:
base,
+ ansi-terminal,
+ concurrent-output,
hedgehog,
+ pqueue,
+ stm,
time,
transformers
hs-source-dirs: test-framework
diff --git a/src/AST.hs b/src/AST.hs
index b2f5ce7..663b83f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -20,10 +20,13 @@
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
+import Data.Functor.Identity
+import Data.Int (Int64)
import Data.Kind (Type)
import Array
import AST.Accum
+import AST.Sparse.Types
import AST.Types
import AST.Weaken
import CHAD.Types
@@ -55,20 +58,37 @@ data Expr x env t where
ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
- ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
- ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
- ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
- ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
+
+ -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
+ -- an implementation is allowed to parallelise this thing and store the b
+ -- values in some implementation-defined order.
+ -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
+ -> Expr x (t1 : t1 : env) (TPair t1 b)
+ -> Expr x env t1
+ -> Expr x env (TArr (S n) t1)
+ -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
+ -- Reverse derivative of EFold1Inner. The contributions to the initial
+ -- element are not yet added together here; we assume a later fusion system
+ -- does that for us.
+ EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
+ -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
+ -> Expr x env (TArr n t2) -- incoming cotangent
+ -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -83,6 +103,9 @@ data Expr x env t where
-- be backpropagated to; 'a' is the inactive part. The dual field of
-- ECustom does not allow a derivative to be generated for 'a', and hence
-- none is propagated.
+ -- No accumulators are allowed inside a, b and tape. This restriction is
+ -- currently not used very much, so could be relaxed in the future; be sure
+ -- to check this requirement whenever it is necessary for soundness!
ECustom :: x t -> STy a -> STy b -> STy tape
-> Expr x [b, a] t -- ^ regular operation
-> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
@@ -90,14 +113,28 @@ data Expr x env t where
-> Expr x env a -> Expr x env b
-> Expr x env t
+ -- fake halfway checkpointing
+ ERecompute :: x t -> Expr x env t -> Expr x env t
+
-- accumulation effect on monoids
+ -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
+ -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
+ -- need to create any zeros.
EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
+ -- The 'Sparse' here is eliminated to dense by UnMonoid.
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
- EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+
+ -- interface of abstract monoidal types
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
@@ -200,6 +237,10 @@ typeOf = \case
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
+
+ EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
+ EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -209,11 +250,13 @@ typeOf = \case
EOp _ op _ -> opt2 op
ECustom _ _ _ _ e _ _ _ _ -> typeOf e
+ ERecompute _ e -> typeOf e
EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ _ -> STNil
+ EAccum _ _ _ _ _ _ _ -> STNil
EZero _ t _ -> fromSMTy t
+ EDeepZero _ t _ -> fromSMTy t
EPlus _ t _ _ -> fromSMTy t
EOneHot _ t _ _ _ -> fromSMTy t
@@ -245,6 +288,9 @@ extOf = \case
EReplicate1Inner x _ _ -> x
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
+ EReshape x _ _ _ -> x
+ EFold1InnerD1 x _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -252,52 +298,63 @@ extOf = \case
EShape x _ -> x
EOp x _ _ -> x
ECustom x _ _ _ _ _ _ _ _ -> x
+ ERecompute x _ -> x
EWith x _ _ _ -> x
- EAccum x _ _ _ _ _ -> x
+ EAccum x _ _ _ _ _ _ -> x
EZero x _ _ -> x
+ EDeepZero x _ _ -> x
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
-mapExt f = \case
- EVar x t i -> EVar (f x) t i
- ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body)
- EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b)
- EFst x e -> EFst (f x) (mapExt f e)
- ESnd x e -> ESnd (f x) (mapExt f e)
- ENil x -> ENil (f x)
- EInl x t e -> EInl (f x) t (mapExt f e)
- EInr x t e -> EInr (f x) t (mapExt f e)
- ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b)
- ENothing x t -> ENothing (f x) t
- EJust x e -> EJust (f x) (mapExt f e)
- EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
- ELNil x t1 t2 -> ELNil (f x) t1 t2
- ELInl x t e -> ELInl (f x) t (mapExt f e)
- ELInr x t e -> ELInr (f x) t (mapExt f e)
- ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c)
- EConstArr x n t a -> EConstArr (f x) n t a
- EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
- EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
- ESum1Inner x e -> ESum1Inner (f x) (mapExt f e)
- EUnit x e -> EUnit (f x) (mapExt f e)
- EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b)
- EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e)
- EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e)
- EConst x t v -> EConst (f x) t v
- EIdx0 x e -> EIdx0 (f x) (mapExt f e)
- EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b)
- EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es)
- EShape x e -> EShape (f x) (mapExt f e)
- EOp x op e -> EOp (f x) op (mapExt f e)
- ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
- EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
- EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
- EZero x t e -> EZero (f x) t (mapExt f e)
- EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
- EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
- EError x t s -> EError (f x) t s
+mapExt f = runIdentity . travExt (Identity . f)
+
+{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
+travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt f = \case
+ EVar x t i -> EVar <$> f x <*> pure t <*> pure i
+ ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
+ EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
+ EFst x e -> EFst <$> f x <*> travExt f e
+ ESnd x e -> ESnd <$> f x <*> travExt f e
+ ENil x -> ENil <$> f x
+ EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
+ EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
+ ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
+ ENothing x t -> ENothing <$> f x <*> pure t
+ EJust x e -> EJust <$> f x <*> travExt f e
+ EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
+ ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
+ ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
+ ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
+ ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
+ EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
+ EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
+ EUnit x e -> EUnit <$> f x <*> travExt f e
+ EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
+ EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
+ EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ EConst x t v -> EConst <$> f x <*> pure t <*> pure v
+ EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
+ EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
+ EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
+ EShape x e -> EShape <$> f x <*> travExt f e
+ EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
+ ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ ERecompute x e -> ERecompute <$> f x <*> travExt f e
+ EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
+ EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
+ EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
+ EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
+ EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
+ EError x t s -> EError <$> f x <*> pure t <*> pure s
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl =
@@ -342,6 +399,9 @@ subst' f w = \case
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
@@ -349,9 +409,11 @@ subst' f w = \case
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
+ ERecompute x e -> ERecompute x (subst' f w e)
EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
EZero x t e -> EZero x t (subst' f w e)
+ EDeepZero x t e -> EDeepZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
@@ -376,11 +438,11 @@ class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
-instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
@@ -398,11 +460,11 @@ styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict
-styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
@@ -423,6 +485,16 @@ envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ EFst _ e -> cheapExpr e
+ ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
+ _ -> False
+
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)
@@ -447,27 +519,30 @@ eidxEq (SS n) a b
(eidxEq n (EFst ext (EVar ext ty (IS IZ)))
(EFst ext (EVar ext ty IZ)))
-emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr =
- let STArr n t = typeOf arr
- in ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
+emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr
+ | STArr n t <- typeOf arr
+ , Dict <- styKnown t
+ = ELet ext arr $
+ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) f
-ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
-ezipWith f arr1 arr2 =
- let STArr n t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ELet ext arr1 $
- ELet ext (weakenExpr WSink arr2) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
- weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
+ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
+ezipWith f arr1 arr2
+ | STArr n t1 <- typeOf arr1
+ , STArr _ t2 <- typeOf arr2
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELet ext arr1 $
+ ELet ext (weakenExpr WSink arr2) $
+ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
+ weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
ezip arr1 arr2 =
@@ -489,11 +564,91 @@ eshapeEmpty (SS n) e =
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
-ezeroD2 :: STy t -> Ex env (D2 t)
-ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext)
+eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
+eshapeConst ShNil = ENil ext
+eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+
+eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
+eshapeProd SZ _ = EConst ext STI64 1
+eshapeProd (SS SZ) e = ESnd ext e
+eshapeProd (SS n) e =
+ eunPair e $ \_ e1 e2 ->
+ EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
+
+eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
+eflatten e =
+ let STArr n _ = typeOf e
+ in elet e $
+ EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
+
+-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
+-- ezeroD2 t ezi = EZero ext (d2M t) ezi
-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
+
+eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
+eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e)
+eunPair e k =
+ elet e $
+ k WSink
+ (EFst ext (evar IZ))
+ (ESnd ext (evar IZ))
+
+efst :: Ex env (TPair a b) -> Ex env a
+efst (EPair _ e1 _) = e1
+efst e = EFst ext e
+
+esnd :: Ex env (TPair a b) -> Ex env b
+esnd (EPair _ _ e2) = e2
+esnd e = ESnd ext e
+
+elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
+elet rhs body
+ | Dict <- styKnown (typeOf rhs)
+ = if cheapExpr rhs
+ then substInline rhs body
+ else ELet ext rhs body
+
+-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
+emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
+emaybe e a b
+ | STMaybe t <- typeOf e
+ , Dict <- styKnown t
+ = EMaybe ext a b e
+
+ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+ecase e a b
+ | STEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ECase ext e a b
+
+elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+elcase e a b c
+ | STLEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELCase ext e a b c
+
+evar :: KnownTy a => Idx env a -> Ex env a
+evar = EVar ext knownTy
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index e84034b..988a450 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,14 +1,13 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
import AST.Types
-import CHAD.Types
import Data
@@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
- AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
- AcIdx (APLeft p) (TLEither a b) = AcIdx p a
- AcIdx (APRight p) (TLEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, shapes info), recursive info)
+type data AIDense = AID | AIS
+
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
+
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
@@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
-lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
-lemZeroInfoD2 STNil = Refl
-lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STScal STI32) = Refl
-lemZeroInfoD2 (STScal STI64) = Refl
-lemZeroInfoD2 (STScal STF32) = Refl
-lemZeroInfoD2 (STScal STF64) = Refl
-lemZeroInfoD2 (STScal STBool) = Refl
-lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
-lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 3d99afe..463586a 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -16,6 +16,7 @@
module AST.Bindings where
import AST
+import AST.Env
import Data
import Lemmas
@@ -27,6 +28,10 @@ data Bindings f env binds where
deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
infixl `BPush`
+bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds)
+bpush b e = b `BPush` (typeOf e, e)
+infixl `bpush`
+
mapBindings :: (forall env' t'. f env' t' -> g env' t')
-> Bindings f env binds -> Bindings g env binds
mapBindings _ BTop = BTop
@@ -41,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) =
let (b', w') = weakenBindings wf w b
in (BPush b' (t, wf w' x), WCopy w')
+weakenBindingsE :: env1 :> env2
+ -> Bindings (Expr x) env1 binds
+ -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2)
+weakenBindingsE = weakenBindings weakenExpr
+
weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
weakenOver SNil w = w
weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
@@ -62,3 +72,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYesR sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index feaaa1e..296c021 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
@@ -10,17 +11,31 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE PatternSynonyms #-}
module AST.Count where
-import Data.Functor.Const
+import Data.Functor.Product
+import Data.Some
+import Data.Type.Equality
import GHC.Generics (Generic, Generically(..))
+import Array
import AST
import AST.Env
import Data
+-- | The monoid operation combines assuming that /both/ branches are taken.
+class Monoid a => Occurrence a where
+ -- | One of the two branches is taken
+ (<||>) :: a -> a -> a
+ -- | This code is executed many times
+ scaleMany :: a -> a
+
+
data Count = Zero | One | Many
deriving (Show, Eq, Ord)
@@ -30,6 +45,10 @@ instance Semigroup Count where
_ <> _ = Many
instance Monoid Count where
mempty = Zero
+instance Occurrence Count where
+ (<||>) = max
+ scaleMany Zero = Zero
+ scaleMany _ = Many
data Occ = Occ { _occLexical :: Count
, _occRuntime :: Count }
@@ -40,120 +59,818 @@ instance Show Occ where
showsPrec d (Occ l r) = showParen (d > 10) $
showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
--- | One of the two branches is taken
-(<||>) :: Occ -> Occ -> Occ
-Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+instance Occurrence Occ where
+ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2)
+ scaleMany (Occ l c) = Occ l (scaleMany c)
+
+
+data Substruc t t' where
+ -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below
+ SsFull :: Substruc t t
+ SsNone :: Substruc t TNil
+ SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b')
+ SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b')
+ SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b')
+ SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a')
+ SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements
+ SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a')
+
+pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t'
+pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2)
+ where SsPair' = SsPair
+{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-}
+
+pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t'
+pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s)
+ where SsArr' = SsArr
+{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-}
+
+instance Semigroup (Some (Substruc t)) where
+ Some SsFull <> _ = Some SsFull
+ _ <> Some SsFull = Some SsFull
+ Some SsNone <> s = s
+ s <> Some SsNone = s
+ Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2)
+ Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2)
+ Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2)
+ Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2)
+ Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2)
+ Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2)
+instance Monoid (Some (Substruc t)) where
+ mempty = Some SsNone
+
+instance TestEquality (Substruc t) where
+ testEquality SsFull s = isFull s
+ testEquality s SsFull = sym <$> isFull s
+ testEquality SsNone SsNone = Just Refl
+ testEquality SsNone _ = Nothing
+ testEquality _ SsNone = Nothing
+ testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+
+isFull :: Substruc t t' -> Maybe (t :~: t')
+isFull SsFull = Just Refl
+isFull SsNone = Nothing -- TODO: nil?
+isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+
+applySubstruc :: Substruc t t' -> STy t -> STy t'
+applySubstruc SsFull t = t
+applySubstruc SsNone _ = STNil
+applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t)
+applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t)
+applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t)
+
+applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t'
+applySubstrucM SsFull t = t
+applySubstrucM SsNone _ = SMTNil
+applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t)
+applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t)
+applySubstrucM _ t = case t of {}
+
+data ExMap a b = ExMap (forall env. Ex env a -> Ex env b)
+ | a ~ b => ExMapId
+
+fromExMap :: ExMap a b -> Ex env a -> Ex env b
+fromExMap (ExMap f) = f
+fromExMap ExMapId = id
+
+simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t'
+simplifySubstruc STNil SsNone = SsFull
+
+simplifySubstruc _ SsFull = SsFull
+simplifySubstruc _ SsNone = SsNone
+simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s)
+simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s)
+simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s)
+
+-- simplifySubstruc' :: Substruc t t'
+-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r
+-- simplifySubstruc' SsFull k = k SsFull ExMapId
+-- simplifySubstruc' SsNone k = k SsNone ExMapId
+-- simplifySubstruc' (SsPair s1 s2) k =
+-- simplifySubstruc' s1 $ \s1' f1 ->
+-- simplifySubstruc' s2 $ \s2' f2 ->
+-- case (s1', s2') of
+-- (SsFull, SsFull) ->
+-- k SsFull (case (f1, f2) of
+-- (ExMapId, ExMapId) -> ExMapId
+-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 ->
+-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2)))
+-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext))))
+-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ)))))
+-- simplifySubstruc' _ _ = _
+
+-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b)
+-- ssUnpair SsFull = (SsFull, SsFull)
+-- ssUnpair SsNone = (SsNone, SsNone)
+-- ssUnpair (SsPair a b) = (a, b)
+
+-- ssUnleft :: Substruc (TEither a b) -> Substruc a
+-- ssUnleft SsFull = SsFull
+-- ssUnleft SsNone = SsNone
+-- ssUnleft (SsEither a _) = a
+
+-- ssUnright :: Substruc (TEither a b) -> Substruc b
+-- ssUnright SsFull = SsFull
+-- ssUnright SsNone = SsNone
+-- ssUnright (SsEither _ b) = b
+
+-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a
+-- ssUnlleft SsFull = SsFull
+-- ssUnlleft SsNone = SsNone
+-- ssUnlleft (SsLEither a _) = a
+
+-- ssUnlright :: Substruc (TLEither a b) -> Substruc b
+-- ssUnlright SsFull = SsFull
+-- ssUnlright SsNone = SsNone
+-- ssUnlright (SsLEither _ b) = b
+
+-- ssUnjust :: Substruc (TMaybe a) -> Substruc a
+-- ssUnjust SsFull = SsFull
+-- ssUnjust SsNone = SsNone
+-- ssUnjust (SsMaybe a) = a
+
+-- ssUnarr :: Substruc (TArr n a) -> Substruc a
+-- ssUnarr SsFull = SsFull
+-- ssUnarr SsNone = SsNone
+-- ssUnarr (SsArr a) = a
+
+-- ssUnaccum :: Substruc (TAccum a) -> Substruc a
+-- ssUnaccum SsFull = SsFull
+-- ssUnaccum SsNone = SsNone
+-- ssUnaccum (SsAccum a) = a
+
+
+type family MapEmpty env where
+ MapEmpty '[] = '[]
+ MapEmpty (t : env) = TNil : MapEmpty env
+
+data OccEnv a env env' where
+ OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top!
+ OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env')
+
+instance Semigroup a => Semigroup (Some (OccEnv a env)) where
+ Some OccEnd <> e = e
+ e <> Some OccEnd = e
+ Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2)
+
+instance Semigroup a => Monoid (Some (OccEnv a env)) where
+ mempty = Some OccEnd
+
+instance Occurrence a => Occurrence (Some (OccEnv a env)) where
+ Some OccEnd <||> e = e
+ e <||> Some OccEnd = e
+ Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2)
+
+ scaleMany (Some OccEnd) = Some OccEnd
+ scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s)
+
+onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env)
+onehotOccEnv IZ v s = Some (OccPush OccEnd v s)
+onehotOccEnv (IS i) v s
+ | Some env' <- onehotOccEnv i v s
+ = Some (OccPush env' mempty SsNone)
+
+occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t')
+occEnvPop (OccPush e _ s) = (e, s)
+occEnvPop OccEnd = (OccEnd, SsNone)
--- | This code is executed many times
-scaleMany :: Occ -> Occ
-scaleMany (Occ l Zero) = Occ l Zero
-scaleMany (Occ l _) = Occ l Many
+occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r
+occEnvPop' (OccPush e _ s) k = k e s
+occEnvPop' OccEnd k = k OccEnd SsNone
+
+occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env)
+occEnvPopSome (Some (OccPush e _ _)) = Some e
+occEnvPopSome (Some OccEnd) = Some OccEnd
+
+occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t))
+occEnvPrj OccEnd _ = mempty
+occEnvPrj (OccPush _ o s) IZ = (o, Some s)
+occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i
+
+occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env'))
+occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ)
+occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i'))
+occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ)
+occEnvPrjS (OccPush e _ _) (IS i)
+ | Some (Pair s' i') <- occEnvPrjS e i
+ = Some (Pair s' (IS i'))
+
+projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small
+projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
+ _ | Just Refl <- testEquality topsbig topssmall -> ex
+
+ (SsFull, SsFull) -> ex
+ (SsNone, SsNone) -> ex
+ (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller"
+ (_, SsNone) ->
+ case typeOf ex of
+ STNil -> ex
+ _ -> use ex $ ENil ext
+
+ (SsPair s1 s2, SsPair s1' s2') ->
+ eunPair ex $ \_ e1 e2 ->
+ EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2)
+ (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex
+ (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex
+
+ (SsEither s1 s2, SsEither s1' s2')
+ | STEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in ecase ex
+ (EInl ext (typeOf e2) e1)
+ (EInr ext (typeOf e1) e2)
+ (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex
+ (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex
+
+ (SsLEither s1 s2, SsLEither s1' s2')
+ | STLEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in elcase ex
+ (ELNil ext (typeOf e1) (typeOf e2))
+ (ELInl ext (typeOf e2) e1)
+ (ELInr ext (typeOf e1) e2)
+ (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex
+ (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex
+
+ (SsMaybe s1, SsMaybe s1')
+ | STMaybe t1 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ in emaybe ex
+ (ENothing ext (typeOf e1))
+ (EJust ext e1)
+ (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
+ (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
+
+ (SsArr s1, SsArr s2)
+ | STArr n t <- typeOf ex ->
+ elet ex $
+ EBuild ext n (EShape ext (evar IZ)) $
+ projectSmallerSubstruc s1 s2
+ (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
+ (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
+
+ (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum"
+ (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex
+ (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex
+
+
+-- | A boolean for each entry in the environment, with the ability to uniformly
+-- mask the top part above a certain index.
+data EnvMask env where
+ EMRest :: Bool -> EnvMask env
+ EMPush :: EnvMask env -> Bool -> EnvMask (t : env)
+
+envMaskPrj :: EnvMask env -> Idx env t -> Bool
+envMaskPrj (EMRest b) _ = b
+envMaskPrj (_ `EMPush` b) IZ = b
+envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx =
- getConst . occCountGeneral
- (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
- (\(Const o) -> Const o)
- (\(Const o1) (Const o2) -> Const (o1 <||> o2))
- (\(Const o) -> Const (scaleMany o))
+occCount idx ex
+ | Some env <- occCountAll ex
+ = fst (occEnvPrj env idx)
+
+occCountAll :: Expr x env t -> Some (OccEnv Occ env)
+occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
+
+pruneExpr :: SList f env -> Expr x env t -> Ex env t
+pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
+ where
+ fullOccEnv :: SList f env -> OccEnv () env env
+ fullOccEnv SNil = OccEnd
+ fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
+
+-- In one traversal, count occurrences of variables and determine what parts of
+-- expressions are actually used. These two results are computed independently:
+-- even if (almost) nothing of a particular term is actually used, variable
+-- references in that term still count as usual.
+--
+-- In @occCountX s t k@:
+-- * s: how much of the result of this term is required
+-- * t: the term to analyse
+-- * k: is passed the actual environment usage of this expression, including
+-- occurrence counts. The callback reconstructs a new expression in an
+-- updated "response" environment. The response must be at least as large as
+-- the computed usages.
+occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
+ -> (forall env'. OccEnv Occ env env'
+ -- response OccEnv must be at least as large as the OccEnv returned above
+ -> (forall env''. OccEnv () env env'' -> Ex env'' t')
+ -> r)
+ -> r
+occCountX initialS topexpr k = case topexpr of
+ EVar _ t i ->
+ withSome (onehotOccEnv i (Occ One One) s) $ \env ->
+ k env $ \env' ->
+ withSome (occEnvPrjS env' i) $ \(Pair s' i') ->
+ projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i')
+ ELet _ rhs body ->
+ occCountX s body $ \envB mkbody ->
+ occEnvPop' envB $ \envB' s1 ->
+ occCountX s1 rhs $ \envR mkrhs ->
+ withSome (Some envB' <> Some envR) $ \env ->
+ k env $ \env' ->
+ ELet ext (mkrhs env') (mkbody (OccPush env' () s1))
+ EPair _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsPair' s1 s2 ->
+ occCountX s1 a $ \env1 mka ->
+ occCountX s2 b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPair ext (mka env') (mkb env')
+ EFst _ e ->
+ occCountX (SsPair s SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EFst ext (mke env')
+ ESnd _ e ->
+ occCountX (SsPair SsNone s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ ESnd ext (mke env')
+ ENil _ ->
+ case s of
+ SsFull -> k OccEnd (\_ -> ENil ext)
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ EInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ EInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ ECase _ e a b ->
+ occCountX s a $ \env1' mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env1' $ \env1 s1 ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
+ ENothing _ t ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t))
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EJust _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsMaybe s' ->
+ occCountX s' e $ \env1 mke ->
+ k env1 $ \env' ->
+ EJust ext (mke env')
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EMaybe _ a b e ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsMaybe s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
+ ELNil _ t1 t2 ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2))
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELCase _ e a b c ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occCountX s c $ \env3' mkc ->
+ occEnvPop' env2' $ \env2 s1 ->
+ occEnvPop' env3' $ \env3 s2 ->
+ occCountX (SsLEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
+ k env $ \env' ->
+ ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
+
+ EConstArr _ n t x ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
+
+ EBuild _ n a b ->
+ case s of
+ SsNone ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsNone b $ \env2'' mkb ->
+ withSome (scaleMany (Some env2'')) $ \env2' ->
+ occEnvPop' env2' $ \env2 s2 ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EBuild ext n (mka env') $
+ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
+ ENil ext) $
+ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX s' b $ \env2'' mkb ->
+ withSome (scaleMany (Some env2'')) $ \env2' ->
+ occEnvPop' env2' $ \env2 s2 ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EBuild ext n (mka env') $
+ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+
+ EFold1Inner _ commut a b c ->
+ occCountX SsFull a $ \env1''' mka ->
+ withSome (scaleMany (Some env1''')) $ \env1'' ->
+ occEnvPop' env1'' $ \env1' s2 ->
+ occEnvPop' env1' $ \env1 s1 ->
+ let s0 = case s of
+ SsNone -> Some SsNone
+ SsArr' s' -> Some s' in
+ withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ occCountX sElt b $ \env2 mkb ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsArr sElt) s $
+ EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush (OccPush env' () sElt) () sElt))
+ (mkb env') (mkc env')
+
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
+
+ EUnit _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX s' e $ \env mke ->
+ k env $ \env' ->
+ EUnit ext (mke env')
+
+ EReplicate1Inner _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX (SsArr s') b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReplicate1Inner ext (mka env') (mkb env')
+
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+
+ EReshape _ n esh e ->
+ case s of
+ SsNone ->
+ occCountX SsNone esh $ \env1 mkesh ->
+ occCountX SsNone e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkesh env') $ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull esh $ \env1 mkesh ->
+ occCountX (SsArr s') e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReshape ext n (mkesh env') (mke env')
+
+ EFold1InnerD1 _ cm e1 e2 e3 ->
+ case s of
+ -- If nothing is necessary, we can execute a fold and then proceed to ignore it
+ SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
+ -- If we don't need the stores, still a fold suffices
+ SsPair' sP SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
+ -- If for whatever reason the additional stores themselves are
+ -- unnecessary but the shape of the array is, then oblige
+ SsPair' sP (SsArr' SsNone) ->
+ let STArr sn _ = typeOf e3
+ foldex =
+ elet (mapExt (\_ -> ext) e3) $
+ EPair ext
+ (EShape ext (evar IZ))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (mapExt (\_ -> ext) (weakenExpr WSink e2))
+ (evar IZ))
+ in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
+ k env1 $ \env' ->
+ eunPair (mkfoldex env') $ \_ eshape earr ->
+ EPair ext earr (EBuild ext sn eshape (ENil ext))
+ -- If at least some of the additional stores are required, we need to keep this a mapAccum
+ SsPair' _ (SsArr' sB) ->
+ -- TODO: propagate usage of primals
+ occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' _ ->
+ occCountX SsFull e2 $ \env2 mkb ->
+ occCountX SsFull e3 $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
+ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkb env') (mkc env')
+
+ EFold1InnerD2 _ cm ef ebog ed ->
+ -- TODO: propagate usage of duals
+ occCountX SsFull ef $ \env1_2' mkef ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' sB ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD2 ext cm
+ (mkef (OccPush (OccPush env' () sB) () SsFull))
+ (mkebog env') (mked env')
+
+ EConst _ t x ->
+ k OccEnd $ \_ ->
+ case s of
+ SsNone -> ENil ext
+ SsFull -> EConst ext t x
+
+ EIdx0 _ e ->
+ occCountX (SsArr s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EIdx0 ext (mke env')
+
+ EIdx1 _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX (SsArr s') a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx1 ext (mka env') (mkb env')
+
+ EIdx _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ _ ->
+ occCountX (SsArr s) a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx ext (mka env') (mkb env')
+
+ EShape _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX (SsArr SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EShape ext (mke env')
+ EOp _ op e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
-data OccEnv env where
- OccEnd :: OccEnv env -- not necessarily top!
- OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
-instance Semigroup (OccEnv env) where
- OccEnd <> e = e
- e <> OccEnd = e
- OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
+ ERecompute _ e ->
+ occCountX s e $ \env1 mke ->
+ k env1 $ \env' ->
+ ERecompute ext (mke env')
-instance Monoid (OccEnv env) where
- mempty = OccEnd
+ EWith _ t a b ->
+ case s of
+ SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with
+ occCountX SsNone b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ withSome (case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone) $ \s1' ->
+ occCountX s1' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $
+ ENil ext
+ SsPair sB sA ->
+ occCountX sB b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ let s1' = case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone in
+ withSome (Some sA <> s1') $ \sA' ->
+ occCountX sA' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $
+ EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA')))
+ SsFull -> occCountX (SsPair SsFull SsFull) topexpr k
-onehotOccEnv :: Idx env t -> Occ -> OccEnv env
-onehotOccEnv IZ v = OccPush OccEnd v
-onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
+ EAccum _ t p a sp b e ->
+ -- TODO: do better!
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull e $ \env3 mke ->
+ withSome (Some env1 <> Some env2) $ \env12 ->
+ withSome (Some env12 <> Some env3) $ \env ->
+ k env $ \env' ->
+ case s of {SsFull -> id; SsNone -> id} $
+ EAccum ext t p (mka env') sp (mkb env') (mke env')
-(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
-OccEnd <||>! e = e
-e <||>! OccEnd = e
-OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
+ EZero _ t e ->
+ occCountX (subZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EZero ext (applySubstrucM s t) (mke env')
+ where
+ subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2)
+ subZeroInfo SsFull = SsFull
+ subZeroInfo SsNone = SsNone
+ subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2)
+ subZeroInfo SsEither{} = error "Either is not a monoid"
+ subZeroInfo SsLEither{} = SsNone
+ subZeroInfo SsMaybe{} = SsNone
+ subZeroInfo (SsArr s') = SsArr (subZeroInfo s')
+ subZeroInfo SsAccum{} = error "Accum is not a monoid"
-scaleManyOccEnv :: OccEnv env -> OccEnv env
-scaleManyOccEnv OccEnd = OccEnd
-scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+ EDeepZero _ t e ->
+ occCountX (subDeepZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EDeepZero ext (applySubstrucM s t) (mke env')
+ where
+ subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2)
+ subDeepZeroInfo SsFull = SsFull
+ subDeepZeroInfo SsNone = SsNone
+ subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo SsEither{} = error "Either is not a monoid"
+ subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s')
+ subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s')
+ subDeepZeroInfo SsAccum{} = error "Accum is not a monoid"
-occEnvPop :: OccEnv (t : env) -> OccEnv env
-occEnvPop (OccPush o _) = o
-occEnvPop OccEnd = OccEnd
+ EPlus _ t a b ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPlus ext (applySubstrucM s t) (mka env') (mkb env')
-occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
+ EOneHot _ t p a b ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb -> -- TODO: do better
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env')
-occCountGeneral :: forall r env t x.
- (forall env'. Monoid (r env'))
- => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
- -> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
- -> (forall env'. r env' -> r env') -- ^ scale-many
- -> Expr x env t -> r env
-occCountGeneral onehot unpush alter many = go WId
+ EError _ t msg ->
+ k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
where
- go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
- go w = \case
- EVar _ _ i -> onehot w i (Occ One One)
- ELet _ rhs body -> re rhs <> re1 body
- EPair _ a b -> re a <> re b
- EFst _ e -> re e
- ESnd _ e -> re e
- ENil _ -> mempty
- EInl _ _ e -> re e
- EInr _ _ e -> re e
- ECase _ e a b -> re e <> (re1 a `alter` re1 b)
- ENothing _ _ -> mempty
- EJust _ e -> re e
- EMaybe _ a b e -> re a <> re1 b <> re e
- ELNil _ _ _ -> mempty
- ELInl _ _ e -> re e
- ELInr _ _ e -> re e
- ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
- EConstArr{} -> mempty
- EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
- ESum1Inner _ e -> re e
- EUnit _ e -> re e
- EReplicate1Inner _ a b -> re a <> re b
- EMaximum1Inner _ e -> re e
- EMinimum1Inner _ e -> re e
- EConst{} -> mempty
- EIdx0 _ e -> re e
- EIdx1 _ a b -> re a <> re b
- EIdx _ a b -> re a <> re b
- EShape _ e -> re e
- EOp _ _ e -> re e
- ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a b e -> re a <> re b <> re e
- EZero _ _ e -> re e
- EPlus _ _ a b -> re a <> re b
- EOneHot _ _ _ a b -> re a <> re b
- EError{} -> mempty
- where
- re :: Monoid (r env') => Expr x env' t'' -> r env'
- re = go w
+ s = simplifySubstruc (typeOf topexpr) initialS
- re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
- re1 = unpush . go (WSink .> w)
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr' SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
-deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
-deleteUnused SNil OccEnd k = k SETop
-deleteUnused (_ `SCons` env) OccEnd k =
- deleteUnused env OccEnd $ \sub -> k (SENo sub)
-deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
- deleteUnused env occenv $ \sub ->
+deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
+deleteUnused SNil (Some OccEnd) k = k SETop
+deleteUnused (_ `SCons` env) (Some OccEnd) k =
+ deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub)
+deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
+ deleteUnused env (Some occenv) $ \sub ->
case count of Zero -> k (SENo sub)
- _ -> k (SEYes sub)
+ _ -> k (SEYesR sub)
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
@@ -162,7 +879,7 @@ unsafeWeakenWithSubenv = \sub ->
Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
where
sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYes _) = Just IZ
+ sinkViaSubenv IZ (SEYesR _) = Just IZ
sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
index 4f34166..85faba3 100644
--- a/src/AST/Env.hs
+++ b/src/AST/Env.hs
@@ -1,59 +1,95 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Env where
+import Data.Type.Equality
+
+import AST.Sparse
import AST.Weaken
+import CHAD.Types
import Data
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
+data Subenv' s env env' where
+ SETop :: Subenv' s '[] '[]
+ SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
+ SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
+deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
+
+type Subenv = Subenv' (:~:)
+type SubenvS = Subenv' Sparse
+
+pattern SEYesR :: forall tenv tenv'. ()
+ => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
+ => Subenv env env' -> Subenv tenv tenv'
+pattern SEYesR s = SEYes Refl s
+
+{-# COMPLETE SETop, SEYesR, SENo #-}
-subList :: SList f env -> Subenv env env' -> SList f env'
+subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
+subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
-subenvAll :: SList f env -> Subenv env env
+subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env
subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
+subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env)
-subenvNone :: SList f env -> Subenv env '[]
+subenvNone :: SList f env -> Subenv' s env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
-subenvOnehot SNil i = case i of {}
+subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t']
+subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
+subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
+subenvOnehot SNil i _ = case i of {}
-subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3
+subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
-subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
-subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1')
+subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
subenvConcat sub1 SETop = sub1
-subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2)
+subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
-sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0
+-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2
+-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r
+-- subenvSplit SNil sub k = k SETop sub
+-- subenvSplit (SCons _ list) (SENo sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SENo sub1) sub2
+-- subenvSplit (SCons _ list) (SEYes s sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SEYes s sub1) sub2
+
+sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
sinkWithSubenv SETop = WId
-sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
sinkWithSubenv (SENo sub) = sinkWithSubenv sub
-wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
wUndoSubenv SETop = WId
-wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
+
+subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env'
+subenvMap _ SNil SETop = SETop
+subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub)
+subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub)
+
+subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env')
+subenvD2E SETop = SETop
+subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub)
+subenvD2E (SENo sub) = SENo (subenvD2E sub)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e09f3ae..68fc629 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
@@ -25,6 +26,7 @@ import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -70,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex
_ -> return "_"
| otherwise = genName' prefix
+-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
@@ -145,12 +148,20 @@ ppExpr' d val expr = case expr of
EMaybe _ a b e -> do
let STMaybe t = typeOf e
- a' <- ppExpr' 11 val a
+ e' <- ppExpr' 0 val e
+ a' <- ppExpr' 0 val a
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 0 (Const name `SCons` val) b
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $
- ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ return $ ppParen (d > 0) $
+ align $
+ group (flatAlt
+ (annotate AKey (ppString "case") <> ppX expr <+> e'
+ <> hardline <> annotate AKey (ppString "of"))
+ (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
+ <> hardline
+ <> indent 2
+ (ppString "Nothing" <+> ppString "->" <+> a'
+ <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
ELNil _ _ _ -> return (ppString "LNil")
@@ -199,8 +210,7 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- let opname = case cm of Commut -> "fold1i(C)"
- Noncommut -> "fold1i"
+ let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
@@ -225,6 +235,34 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EReshape _ n esh e -> do
+ esh' <- ppExpr' 11 val esh
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+
+ EFold1InnerD1 _ cm a b c -> do
+ name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf b) IZ a
+ a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1iD1" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ namef1 <- genNameIfUsedIn tB (IS IZ) ef
+ namef2 <- genNameIfUsedIn t2 IZ ef
+ ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
+ ebog' <- ppExpr' 11 val ebog
+ ed' <- ppExpr' 11 val ed
+ let opname = "fold1iD2" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr)
+ [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
+
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -280,6 +318,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
@@ -292,18 +334,24 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
EZero _ t e1 -> do
e1' <- ppExpr' 11 val e1
return $ ppParen (d > 0) $
annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
EPlus _ t a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
@@ -356,6 +404,20 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
+
+ppCommut :: Commutative -> String
+ppCommut Commut = "(C)"
+ppCommut Noncommut = ""
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -388,6 +450,7 @@ ppSTy' :: Int -> STy t -> Doc q
ppSTy' _ STNil = ppString "1"
ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
ppSTy' d (STArr n t) = ppParen (d > 10) $
ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
@@ -398,7 +461,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF64 -> "f64"
STBool -> "bool"
ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
-ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSMTy :: Int -> SMTy t -> String
ppSMTy d ty = render $ ppSMTy' d ty
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
new file mode 100644
index 0000000..2a29799
--- /dev/null
+++ b/src/AST/Sparse.hs
@@ -0,0 +1,287 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
+module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
+
+import Data.Type.Equality
+
+import AST
+import AST.Sparse.Types
+import Data (SBool(..))
+
+
+sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
+sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
+sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
+sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
+sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
+ eunPair e1 $ \w1 e1a e1b ->
+ eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
+ EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
+ (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
+sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
+ elet e2 $
+ elcase (weakenExpr WSink e1)
+ (evar IZ)
+ (elcase (evar (IS IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
+ (elcase (evar (IS IZ))
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
+ elet e2 $
+ emaybe (weakenExpr WSink e1)
+ (evar IZ)
+ (emaybe (evar (IS IZ))
+ (EJust ext (evar IZ))
+ (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
+sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+
+
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
+data Injection sp a b where
+ -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
+ -- 'sparsePlusS' can provide injections even if the caller doesn't require
+ -- them. This simplifies the sparsePlusS code.
+ Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
+ Noinj :: Injection False a b
+
+withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
+withInj (Inj f) k = Inj (k f)
+withInj Noinj _ = Noinj
+
+withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
+ -> ((forall e. Ex e a1 -> Ex e b1)
+ -> (forall e. Ex e a2 -> Ex e b2)
+ -> (forall e'. Ex e' a' -> Ex e' b'))
+ -> Injection sp a' b'
+withInj2 (Inj f) (Inj g) k = Inj (k f g)
+withInj2 Noinj _ _ = Noinj
+withInj2 _ Noinj _ = Noinj
+
+-- | This function produces quadratically-sized code in the presence of nested
+-- dynamic sparsity. TODO can this be improved?
+sparsePlusS
+ :: SBool inj1 -> SBool inj2
+ -> SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
+ -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
+
+-- simplifications
+sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
+ sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
+sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
+ sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
+
+sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
+ let ta = applySparse sp1 (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
+ let tb = applySparse sp2 (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
+ let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
+ let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
+ let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
+ let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
+sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
+
+-- TODO: sparse of Just is just Maybe
+
+-- dense plus
+sparsePlusS _ _ t sp1 sp2 k
+ | Just Refl <- isDense t sp1
+ , Just Refl <- isDense t sp2
+ = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+
+-- handle absents
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+
+-- double sparse yields sparse
+sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- single sparse can yield non-sparse if the other argument is always present
+sparsePlusS SF _ t (SpSparse sp1) sp2 k =
+ sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
+ k sp3 Noinj (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
+ sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
+ k sp3 inj2 inj1 (flip plus)
+
+-- products
+sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
+ sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
+ sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
+ k (SpPair sp3a sp3b)
+ (withInj2 minj13a minj13b $ \inj13a inj13b ->
+ \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
+ (withInj2 minj23a minj23b $ \inj23a inj23b ->
+ \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
+ (\x1 x2 ->
+ eunPair x1 $ \w1 x1a x1b ->
+ eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
+ EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
+
+-- coproducts
+sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
+ sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
+ let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
+ inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
+ inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
+ in
+ k (SpLEither sp3a sp3b)
+ (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
+ (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x2 $
+ elcase (weakenExpr WSink x1)
+ (elcase (evar IZ)
+ nil
+ (inl (inj23a (evar IZ)))
+ (inr (inj23b (evar IZ))))
+ (elcase (evar (IS IZ))
+ (inl (inj13a (evar IZ)))
+ (inl (plusa (evar (IS IZ)) (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
+ (elcase (evar (IS IZ))
+ (inr (inj13b (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
+ (inr (plusb (evar (IS IZ)) (evar IZ)))))
+
+-- maybe
+sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- dense array cotangents simply recurse
+sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
+ k (SpArr sp3)
+ (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
+ (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+ (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
+ (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+
+-- scalars
+sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs
new file mode 100644
index 0000000..10cac4e
--- /dev/null
+++ b/src/AST/Sparse/Types.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.Sparse.Types where
+
+import AST.Types
+
+import Data.Kind (Type, Constraint)
+import Data.Type.Equality
+
+
+data Sparse t t' where
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
+deriving instance Show (Sparse t t')
+
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
+
+
+class IsSubType s where
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull _ = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ SMTy
+ subtApply = applySparse
+
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 159934d..d276e44 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' = \sub -> \case
EVar _ t i -> sub t i WId
- ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
ECase x e a b ->
let STEither t1 t2 = typeOf e
in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
@@ -35,6 +35,13 @@ splitLets' = \sub -> \case
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD1 x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD2 x cm a b c ->
+ let STArr _ tB = typeOf b
+ STArr _ t2 = typeOf c
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
@@ -54,6 +61,7 @@ splitLets' = \sub -> \case
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
@@ -61,9 +69,11 @@ splitLets' = \sub -> \case
EShape x e -> EShape x (splitLets' sub e)
EOp x op e -> EOp x op (splitLets' sub e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
+ EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
@@ -87,15 +97,42 @@ splitLets' = \sub -> \case
-> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
split2 sub tbind1 tbind2 body =
let (ptrs1', bs1') = split @env' tbind1
- bs1 = fst (weakenBindings weakenExpr WSink bs1')
+ bs1 = fst (weakenBindingsE WSink bs1')
(ptrs2, bs2) = split @(bind1 : env') tbind2
in letBinds bs1 $
- letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
_ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env')))
t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
body
+ -- TODO: abstract this to splitN lol wtf
+ _split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ _split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ let (ptrs1, bs1') = split @env' tbind1
+ (ptrs2, bs2') = split @(bind1 : env') tbind2
+ (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
+ (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
+ bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
+ bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
+ bs3 = fst (weakenBindingsE WSink bs3')
+ b1 = bindingsBinds bs1
+ b2 = bindingsBinds bs2
+ b3 = bindingsBinds bs3
+ b4 = bindingsBinds bs4
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
+ _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
+ _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
+ _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
+ t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
+ body
+
type family Split t where
Split (TPair a b) = SplitRec (TPair a b)
Split _ = '[]
@@ -123,11 +160,11 @@ split typ = case typ of
STPair{} -> splitRec (EVar ext typ IZ) typ
STNil -> other
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
@@ -142,11 +179,11 @@ splitRec rhs typ = case typ of
(p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex env '[t])
other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index efb1e04..4ddcb50 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -5,9 +5,9 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeData #-}
module AST.Types where
import Data.Int (Int32, Int64)
@@ -23,12 +23,11 @@ type data Ty
= TNil
| TPair Ty Ty
| TEither Ty Ty
+ | TLEither Ty Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
| TAccum Ty -- ^ contained type must be a monoid type
- -- sparse monoid types
- | TLEither Ty Ty
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -37,12 +36,11 @@ data STy t where
STNil :: STy TNil
STPair :: STy a -> STy b -> STy (TPair a b)
STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
STMaybe :: STy a -> STy (TMaybe a)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
- -- sparse monoid types
- STLEither :: STy a -> STy b -> STy (TLEither a b)
deriving instance Show (STy t)
instance GCompare STy where
@@ -53,6 +51,8 @@ instance GCompare STy where
STPair{} _ -> GLT ; _ STPair{} -> GGT
(STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
(STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
(STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
@@ -60,9 +60,7 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
- (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -173,15 +171,25 @@ type family ScalIsIntegral t where
ScalIsIntegral TBool = False
-- | Returns true for arrays /and/ accumulators.
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = True
-hasArrays (STLEither a b) = hasArrays a || hasArrays b
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
type family Tup env where
Tup '[] = TNil
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 3d5f544..a22b73f 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -1,18 +1,22 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
+module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
+import AST.Sparse.Types
import Data
--- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
--- into their concrete implementations.
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -40,6 +44,9 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
@@ -47,12 +54,17 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-zero SMTNil _ = ENil ext
+-- don't destroy the effects!
+zero SMTNil e = ELet ext e $ ENil ext
zero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
@@ -65,8 +77,30 @@ zero (SMTScal t) _ = case t of
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil e = elet e $ ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-plus SMTNil _ _ = ENil ext
+-- don't destroy the effects!
+plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
plus (SMTPair t1 t2) a b =
let t = STPair (fromSMTy t1) (fromSMTy t2)
in ELet ext a $
@@ -104,7 +138,7 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
ELet ext arg $
@@ -142,3 +176,78 @@ onehot typ topprj idx arg = case (typ, topprj) of
(onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
(ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index d882e28..3a97fd1 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where
import Data.Bifunctor (first)
import Data.Functor.Const
+import Data.GADT.Compare
import Data.Kind (Type)
import Data
@@ -31,6 +32,11 @@ data Idx env t where
IS :: Idx env t -> Idx (a : env) t
deriving instance Show (Idx env t)
+instance GEq (Idx env) where
+ geq IZ IZ = Just Refl
+ geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl
+ geq _ _ = Nothing
+
splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 6752c24..c6efe37 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where
SSegNil :: SSegments '[]
SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where
+instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where
fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
auto :: KnownListSpine list => SList (Const ()) list
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 20575b3..6301dc1 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -28,9 +28,9 @@ data ValId t where
VIPair :: ValId a -> ValId b -> ValId (TPair a b)
VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
- VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
@@ -244,6 +244,35 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> pure sh
pure (res, EMinimum1Inner res e1')
+ EReshape _ dim e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> shidsToVec dim v1
+ pure (res, EReshape res dim e1' e2')
+
+ EFold1InnerD1 _ cm e1 e2 e3 -> do
+ let t1 = typeOf e2
+ x1 <- genIds t1
+ x2 <- genIds t1
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ sh'@(_ :< sh) = v3
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
+ pure (res, EFold1InnerD1 res cm e1' e2' e3')
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ xf1 <- genIds t2
+ xf2 <- genIds tB
+ (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef
+ (v2, e2') <- idana env ebog
+ (_, e3') <- idana env ed
+ let VIArr _ sh@(_ :< sh') = v2
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm e1' e2' e3')
+
EConst _ t val -> do
res <- VIScal <$> genId
pure (res, EConst res t val)
@@ -294,6 +323,10 @@ idana env expr = case expr of
res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
EWith _ t e1 e2 -> do
let t1 = typeOf e1
(_, e1') <- idana env e1
@@ -303,11 +336,11 @@ idana env expr = case expr of
let res = VIPair v2 x2
pure (res, EWith res t e1' e2')
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- pure (VINil, EAccum VINil t prj e1' e2' e3')
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
EZero _ t e1 -> do
-- Approximate the result of EZero to be independent from the zero info
@@ -316,6 +349,13 @@ idana env expr = case expr of
res <- genIds (fromSMTy t)
pure (res, EZero res t e1')
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
+
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
@@ -367,11 +407,11 @@ genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId
-genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil
diff --git a/src/Array.hs b/src/Array.hs
index 707dce2..6ceb9fe 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -91,6 +91,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l)
arrayToList :: Array n t -> [t]
arrayToList (Array _ v) = V.toList v
+arrayReshape :: Shape n -> Array m t -> Array n t
+arrayReshape sh (Array sh' v)
+ | shapeSize sh == shapeSize sh' = Array sh v
+ | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")"
+
arrayUnit :: t -> Array Z t
arrayUnit x = Array ShNil (V.singleton x)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index ac308ac..7594a0f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -11,6 +12,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -33,15 +35,14 @@ module CHAD (
import Data.Functor.Const
import Data.Some
-import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
-import GHC.Stack (HasCallStack)
import Analysis.Identity (ValId(..), validSplitEither)
import AST
import AST.Bindings
import AST.Count
import AST.Env
+import AST.Sparse
import AST.Weaken.Auto
import CHAD.Accum
import CHAD.EnvDescr
@@ -62,14 +63,20 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds
- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollect BTop SETop _ = ENil ext
-bindingsCollect (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
- (bindingsCollect binds sub (w .> WSink))
-bindingsCollect (BPush binds _) (SENo sub) w =
- bindingsCollect binds sub (w .> WSink)
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
+
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
@@ -140,7 +147,7 @@ growRecon t ts (Reconstructor unfbs bs)
-- Add a 'fst' at the bottom of the builder stack.
-- First we have to weaken most of 'bs' to skip one more binding in the
-- unfolder stack above it.
- (BPush (fst (weakenBindings weakenExpr
+ (BPush (fst (weakenBindingsE
(wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
(WSink :: env :> (Tape (t : ts) : env))) bs))
(t
@@ -190,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)
-- incidentally also add a bunch of additional bindings, namely 'Reverse
-- (TapeUnfoldings binds)', so the calling code just has to skip those in
-- whatever it wants to do.
-reconstructBindings :: SList STy binds -> Idx env (Tape binds)
- -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
+reconstructBindings :: SList STy binds
+ -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
,SList STy (Reverse (TapeUnfoldings binds)))
-reconstructBindings binds tape =
- let Reconstructor unf build = buildReconstructor binds
- in (fst $ weakenBindings weakenExpr (WIdx tape)
- (bconcat (mapBindings fromUnfExpr unf) build)
- ,sreverse (stapeUnfoldings binds))
+reconstructBindings binds =
+ (\tape -> let Reconstructor unf build = buildReconstructor binds
+ in fst $ weakenBindingsE (WIdx tape)
+ (bconcat (mapBindings fromUnfExpr unf) build)
+ ,sreverse (stapeUnfoldings binds))
---------------------------------- DERIVATIVES ---------------------------------
@@ -227,26 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
-> D2Op (TScal a) t
@@ -261,11 +279,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -293,7 +311,7 @@ conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
= Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
- | Idx2Me (Idx (Select env sto "merge") t)
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
@@ -314,67 +332,158 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
------------------------------------- MONOIDS -----------------------------------
-
-zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
-zeroTup SNil = ENil ext
-zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t)
+----------------------------------- SPARSITY -----------------------------------
------------------------------------- SUBENVS -----------------------------------
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
-subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (EPlus ext (d2M t)
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t)
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
-assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
-assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
+
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
--------------------------------- ACCUMULATORS ---------------------------------
@@ -407,8 +516,8 @@ accumPromote :: forall dt env sto proxy r.
-- accumulators.
-> (forall shbinds.
SList STy shbinds
- -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
@@ -422,14 +531,14 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
k (storepl `DPush` (t, vid, SAccum))
envpro
prosub
- (SEYes accrevsub)
+ (SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -449,7 +558,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
k (storepl `DPush` (t, vid, SAccum))
(t `SCons` envpro)
- (SEYes prosub)
+ (SEYesR prosub)
(SENo accrevsub)
(let accumMap' = VarMap.sink1 accumMap
in case fromArrayValId vid of
@@ -466,7 +575,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- Discrete values are left as-is, nothing to do
@@ -484,6 +593,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
STNil -> True
STPair a b -> isDiscrete a && isDiscrete b
STEither a b -> isDiscrete a && isDiscrete b
+ STLEither a b -> isDiscrete a && isDiscrete b
STMaybe a -> isDiscrete a
STArr _ a -> isDiscrete a
STScal st -> case st of
@@ -493,26 +603,45 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
STF64 -> False
STBool -> True
STAccum{} -> False
- STLEither a b -> isDiscrete a && isDiscrete b
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
-data Ret env0 sto t =
- forall shbinds tapebinds env0Merge.
+data Ret env0 sto sd t =
+ forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (Ret env0 sto t)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
+
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
-data RetPair env0 sto env shbinds tapebinds t =
- forall env0Merge.
- RetPair (Ex (Append shbinds env) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
data Rets env0 sto env list =
forall shbinds tapebinds.
@@ -521,113 +650,149 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
weakenRetPair :: SList STy shbinds -> env :> env'
- -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
weakenRets w (Rets binds tapesub list) =
- let (binds', _) = weakenBindings weakenExpr w binds
+ let (binds', _) = weakenBindingsE w binds
in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
Descr env0 sto
-> SList f b1 -> SList f b2
-> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
- -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
- -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
-rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t))
- &. #t2 (subList b2 subtape2)
- &. #t1 (subList b1 subtape1)
- &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#t2 :++: #tl))
- (#d :++: ((#t2 :++: #t1) :++: #tl)))
- d)
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
-retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
| Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
- <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
, Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
, Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
- = Rets (bconcat b binds)
+ = Rets (bconcat e0 binds)
(subenvConcat subtape subtape2)
- (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
sub
- (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
subtape subtape2)
pairs))
freezeRet :: Descr env sto
- -> Ret env sto t
+ -> Ret env sto (D2 t) t
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
- let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
+ let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tape (subList (bindingsBinds e0) subtape)
- &. #shbinds (bindingsBinds e0)
- &. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr))
+ (ELet ext (weakenExpr (autoWeak library
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
- expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
---------------------------- THE CHAD TRANSFORMATION ---------------------------
-drev :: forall env sto t.
+drev :: forall env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
- -> Expr ValId env t -> Ret env sto t
-drev des accumMap = \case
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
- (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ (subenvNone (d2e (select SMerge des)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (select SMerge des) tupI)
- (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
Idx2Di _ ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
- , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body
- , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
- subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
- Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
- (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
+ , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
+ Ret (bconcat (rhs0 `bpush` rhs1) body0')
+ (subenvConcat subtapeRHS subtapeBody)
(weakenExpr wbody0' body1)
subBoth
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
&. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
@@ -637,317 +802,461 @@ drev des accumMap = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EVar ext (contribTupTy des subRHS) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil
- , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
- subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
subtape
(EPair ext a1 b1)
subBoth
- (EMaybe ext
- (zeroTup (subList (select SMerge des) subBoth))
- (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
- (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
Ret e0
subtape
(EFst ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
Ret e0
subtape
(ESnd ext e1)
sub
- (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (weakenExpr (WCopy WSink) e2)
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy WSink) e2))
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
ECase _ e (a :: Expr _ _ t) b
- | STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
- , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
- , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
+ , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
, let (bindids1, bindids2) = validSplitEither (extOf e)
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
- , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
- , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollect a0 subtapeA
- , let collectB = bindingsCollect b0 subtapeB
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
- , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
- , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0
+ , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
->
- subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
- subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
- let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in
- Ret (e0 `BPush`
- (tPrimal,
- ECase ext e1
- (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
- (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
- (SEYes subtapeE)
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
+ Ret (e0 `bpush` ECase ext e1
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))
+ (SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
- (ELet ext
+ (elet
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ
- in letBinds rebinds $
- ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA
+ in letBinds (rebinds IZ) $
ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #ta0 (subList (bindingsBinds a0) subtapeA)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
&. #prea0 prerebinds
- &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #ta0 :++: #tl)
(#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
- (ELInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
- in letBinds rebinds $
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB
+ in letBinds (rebinds IZ) $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tb0 (subList (bindingsBinds b0) subtapeB)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
&. #preb0 prerebinds
- &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #tb0 :++: #tl)
(#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
- (ELInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
- ELet ext
- (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
EConst _ t val ->
Ret BTop
SETop
(EConst ext t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
case d2op op of
Linear d2opfun ->
Ret e0
subtape
(d1op op e1)
sub
- (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
- Ret (e0 `BPush` (d1 (typeOf e), e1))
- (SEYes subtape)
+ Ret (e0 `bpush` e1)
+ (SEYesR subtape)
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
- ECustom _ _ _ storety _ pr du a b
+ ECustom _ _ tb _ srce pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
- | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->
- Ret (binds `BPush` (typeOf a1, a1)
- `BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
- `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYes (SENo (SENo (SENo subtape))))
- (EFst ext (EVar ext (typeOf pr) (IS IZ)))
- bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
- weakenExpr (WCopy (WSink .> WSink)) b2)
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)
+ `bpush` ESnd ext (EVar ext (typeOf pr) IZ))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
+ (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @sd)
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
EError _ t s ->
Ret BTop
SETop
(EError ext (d1 t) s)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EConstArr _ n t val ->
Ret BTop
SETop
(EConstArr ext n t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result
+ | SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
- deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
- subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
+ let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
- case assertSubenvEmpty sub of { Refl ->
+ case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 ->
+ case lemAppendNil @e_binds of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollect e0 subtapeE in
- Ret (BTop `BPush` (shty, letBinds she0 she1)
- `BPush` (STArr ndim (STPair (d1 eltty) tapety)
- ,EBuild ext ndim
- (EVar ext shty IZ)
- (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#ix :++: #sh :++: #d1env))
- e0)) $
- let w = autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #e0 (bindingsBinds e0)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#e0 :++: #ix :++: #sh :++: #d1env)
- in EPair ext (weakenExpr w e1) (collectexpr w)))
- `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYes (SENo (SEYes SETop)))
- (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
- (subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- EMaybe ext
- (zeroTup envPro)
- (ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (auto1 @(Tape e_tape))
- &. #ix (auto1 @shty)
- &. #darr (auto1 @(TArr ndim (D2 eltty)))
- &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
- &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
- &. #sh (auto1 @shty)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
- (EVar ext (d2 (STArr ndim eltty)) IZ))
+ let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ Ret (mergePrimalBindings
+ `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)
+ `bpush` EBuild ext ndim
+ (EVar ext shty IZ)
+ (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#ix :++: #sh :++: #propr :++: #d1env))
+ e0)) $
+ let w = autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env)
+ w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
+ in EPair ext (weakenExpr w e1) (collectexpr w'))
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))
+ (SEYesR (SENo (SEYesR (subenvAll (d1e envPro)))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in
+ ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE)
+ in letBinds (rebinds IZ) $
+ weakenExpr (autoWeak (#d (auto1 @sdElt)
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2)
}}
+ EFold1Inner _ commut origef ex₀ earr
+ | SpArr @_ @sdElt sdElt <- sd
+ , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
+ , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
+ deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
+ case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ &. #parr (auto1 @(TArr (S n) (D1 elt)))
+ &. #px₀ (auto1 @(D1 elt))
+ &. #px (auto1 @(D1 elt))
+ &. #pzi (auto1 @(ZeroInfo (D2 elt)))
+ &. #primal (primalTy `SCons` SNil)
+ &. #darr (auto1 @(TArr n sdElt))
+ &. #d (auto1 @(D2 elt))
+ &. #x₀abinds (bindingsBinds bindsx₀a)
+ &. #fbinds (bindingsBinds ef0)
+ &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #primalzip (zipPrimalTy `SCons` SNil)
+ &. #efPrerebinds efPrerebinds
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
+ wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a mergePrimalBindings'
+ `bpush` weakenExpr wOverPrimalBindings ex₀1
+ `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
+ `bpush` EFold1InnerD1 ext commut
+ (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
+ letBinds (fst (weakenBindingsE (autoWeak library
+ (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ layout)
+ ef0)) $
+ elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: layout))
+ ef1) $
+ EPair ext
+ (evar IZ)
+ (EPair ext
+ (evar IZ)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
+ (EVar ext (d1 eltty) (IS (IS IZ)))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
+ (EFst ext (EVar ext primalTy IZ))
+ subx₀af
+ (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ elet
+ (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
+ makeAccumulators (autoWeak library #propr layout1) envPro $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
+ elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
+ elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
+ letBinds (efRebinds (IS (IS IZ))) $
+ let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
+ elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
+ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ ef2) $
+ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
+ (ezip
+ (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ plus_x₀a_f
+ (plus_x₀_a
+ (elet (EIdx0 ext
+ (EFold1Inner ext Commut
+ (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
+ (eflatten (EFst ext (EFst ext (evar IZ)))))) $
+ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
+ ex₀2)
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (ESnd ext (evar IZ)))
+ }
+
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
Ret e0
subtape
(EUnit ext e1)
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EReplicate1Inner _ en e
- -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
, let STArr ndim eltty = typeOf e ->
- Ret binds
- subtape
- (EReplicate1Inner ext en1 e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (ezeroD2 eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (sparsePlus (d2M eltty) sdElt'
+ (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
+ (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
, STArr _ t <- typeOf e ->
Ret e0
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
- weakenExpr (WCopy WSink) e2)
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
{-
@@ -956,9 +1265,9 @@ drev des accumMap = \case
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr (SS n) eltty <- typeOf e ->
- Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
- `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
- (SEYes (SENo subtape))
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))
+ (SEYesR (SENo subtape))
(EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
(weakenExpr (WSink .> WSink) ei1))
sub
@@ -969,60 +1278,73 @@ drev des accumMap = \case
-}
EIdx _ e ei
- -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
- , STArr n eltty <- typeOf e
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
- , Refl <- lemZeroInfoD2 eltty
- , let tIxN = tTup (sreplicate n tIx) ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
- `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYes (SEYes (SENo subtape)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- sub
- (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere))
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
- (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext)))
- (ENil ext))
- (EVar ext (d2 eltty) IZ)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (typeOf e1) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des accumMap e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
- Ret (e0 `BPush` (STArr (SS n) t, e1)
- `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
- (SEYes (SENo subtape))
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ))
+ (SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
- -- These should be the next to be implemented, I think
- EFold1Inner{} -> err_unsupported "EFold1Inner"
+ EReshape _ n esh e
+ | SpArr sd' <- sd
+ , STArr orign t <- typeOf e
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , Refl <- indexTupD1Id n ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ))
+ (SEYesR (SENo subtape))
+ (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh))
+ (EVar ext (STArr orign (d1 t)) (IS IZ)))
+ sub
+ (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
@@ -1033,104 +1355,130 @@ drev des accumMap = \case
ELCase{} -> err_unsupported "ELCase"
EWith{} -> err_accum
- EAccum{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
+ err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYes (SEYes subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (ezeroD2 t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `bpush` e1
+ `bpush` extremum (EVar ext at IZ))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
-data RetScoped env0 sto a s t =
- forall shbinds tapebinds env0Merge.
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
- (Subenv shbinds tapebinds)
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a))))
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
-deriving instance Show (RetScoped env0 sto a s t)
+deriving instance Show (RetScoped env0 sto a s sd t)
-drevScoped :: forall a s env sto t.
+drevScoped :: forall a s env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
-> Expr ValId (a : env) t
- -> RetScoped env sto a s t
-drevScoped des accumMap argty argsto argids expr = case argsto of
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
case sub of
- SEYes sub' -> RetScoped e0 subtape e1 sub' e2
- SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
SAccum
- | Just (VIArr i _) <- argids
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
, Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
, Just Refl <- testEquality foundTy (STAccum (d2M argty))
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr ->
- RetScoped e0 subtape e1 sub $
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ weakenExpr (autoWeak (#d (auto1 @sd)
&. #body (subList (bindingsBinds e0) subtape)
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
- -- Our contribution to the binding's cotangent _here_ is
- -- zero, because we're contributing to an earlier binding
- -- of the same value instead.
- (EPair ext e2 (ezeroD2 argty))
+ (EPair ext e2 (ENil ext))
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr ->
- RetScoped e0 subtape e1 sub $
- EWith ext (d2M argty) (ezeroD2 argty) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds e0) subtape)
- &. #ac (auto1 @(TAccum (D2 a)))
- &. #tl (d2ace (select SAccum des)))
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
- (#ac :++: #d :++: #body :++: #tl))
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
e2
SDiscr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
- RetScoped e0 subtape e1 sub e2
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
+ = mapExt (const ext) e
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index d8a71b5..a7bc53f 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -1,18 +1,59 @@
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+-- | TODO this module is a grab-bag of random utility functions that are shared
+-- between CHAD and CHAD.Top.
module CHAD.Accum where
import AST
import CHAD.Types
import Data
+import AST.Env
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t =
- makeAccumulators envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+-- The weakening is necessary because we need to initialise the created
+-- accumulators with zeros. Those zeros are deep and need full primals. This
+-- means, in the end, that primals corresponding to environment entries
+-- promoted to an accumulator with accumPromote in CHAD need to be stored for
+-- the dual.
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
@@ -25,3 +66,7 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
index 4c287d7..49ae0e6 100644
--- a/src/CHAD/EnvDescr.hs
+++ b/src/CHAD/EnvDescr.hs
@@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env'
-> r)
-> r
subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k =
+subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
- SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e)
+ SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e)
+ SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e)
+ SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)
subDescr (des `DPush` (_, _, sto)) (SENo sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
@@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des
select s@SAccum (DPush des (_, _, SDiscr)) = select s des
select s@SMerge (DPush des (_, _, SDiscr)) = select s des
select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des)
+
+selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s)
+selectSub _ DTop = SETop
+selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 9e7e7f5..4814bdf 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -12,6 +12,8 @@ module CHAD.Top where
import Analysis.Identity
import AST
+import AST.Env
+import AST.Sparse
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
@@ -41,39 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (t `SCons` env) k = accumDescr env $ \des ->
- if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
- else k (des `DPush` (t, Nothing, SMerge))
-
-d1Identity :: STy t -> D1 t :~: t
-d1Identity = \case
- STNil -> Refl
- STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STMaybe t | Refl <- d1Identity t -> Refl
- STArr _ t | Refl <- d1Identity t -> Refl
- STScal _ -> Refl
- STAccum{} -> error "Accumulators not allowed in input program"
- STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
-
-d1eIdentity :: SList STy env -> D1E env :~: env
-d1eIdentity SNil = Refl
-d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+ if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
reassembleD2E :: Descr env sto
+ -> D1E env :> env'
-> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
-> Ex env' (Tup (D2E env))
-reassembleD2E DTop _ = ENil ext
-reassembleD2E (des `DPush` (_, _, SAccum)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
- (ESnd ext (EVar ext (typeOf e) IZ))))
- (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (_, _, SMerge)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
- (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
- (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t)
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
@@ -83,21 +71,22 @@ chad config env (term :: Ex env t)
let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
tvar = STPair t1 (tTup (d2e (select SAccum descr)))
in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
- makeAccumulators (select SAccum descr) $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #acenv (d2ace (select SAccum descr))
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr VarMap.empty term')) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
- (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
- (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term')
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
where
term' = identityAnalysis env (splitLets term)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 74e7dbd..44ac20e 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
+import AST.Accum
import AST.Types
import Data
@@ -11,19 +13,19 @@ type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
- D1 (TLEither a b) = TLEither (D1 a) (D1 b)
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
D2 (TEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TMaybe (TArr n (D2 t))
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
- D2 (TLEither a b) = TLEither (D2 a) (D2 b)
type family D2s t where
D2s TI32 = TNil
@@ -48,11 +50,11 @@ d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
-d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -60,10 +62,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env
d2M :: STy t -> SMTy (D2 t)
d2M STNil = SMTNil
-d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STMaybe t) = SMTMaybe (d2M t)
-d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STArr n t) = SMTArr n (d2M t)
d2M (STScal t) = case t of
STI32 -> SMTNil
STI64 -> SMTNil
@@ -71,7 +74,6 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
-d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
@@ -95,6 +97,8 @@ data CHADConfig = CHADConfig
chcCaseArrayAccum :: Bool
, -- | Introduce top-level arguments containing arrays in accumulator mode.
chcArgArrayAccum :: Bool
+ , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
+ chcSmartWith :: Bool
}
deriving (Show)
@@ -103,12 +107,14 @@ defaultConfig = CHADConfig
{ chcLetArrayAccum = False
, chcCaseArrayAccum = False
, chcArgArrayAccum = False
+ , chcSmartWith = False
}
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum c = c { chcLetArrayAccum = True
, chcCaseArrayAccum = True
- , chcArgArrayAccum = True }
+ , chcArgArrayAccum = True
+ , chcSmartWith = True }
------------------------------------ LEMMAS ------------------------------------
@@ -116,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
+
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index 87c01cb..888fed4 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -19,29 +19,25 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
STEither t1 t2 -> case der of
Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
Just d -> case (primal, d) of
(Left p, Left d') -> Left (toTan t1 p d')
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"
- STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t -> case der of
- Nothing -> arrayMap (zeroTan t) primal
- Just d
- | arrayShape primal == arrayShape d ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
- STScal sty -> case sty of
- STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
- STAccum{} -> error "Accumulators not allowed in input program"
STLEither t1 t2 -> case (primal, der) of
(_, Nothing) -> Nothing
(Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
(Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
_ -> error "Primal and cotangent disagree on LEither alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Compile.hs b/src/Compile.hs
index 6ba3a39..f2063ee 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -34,6 +34,7 @@ import Foreign
import GHC.Exts (int2Word#, addr2Int#)
import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (mkIOError, userErrorType)
@@ -45,6 +46,7 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
@@ -77,7 +79,7 @@ compile = \env expr -> do
let (source, offsets) = compileToString codeID env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
- lib <- buildKernel source ["kernel"]
+ lib <- buildKernel source "kernel"
let result_type = typeOf expr
result_size = sizeofSTy result_type
@@ -86,7 +88,7 @@ compile = \env expr -> do
allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
serialiseArguments args ptr $ do
- callKernelFun "kernel" lib ptr
+ callKernelFun lib ptr
ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
when (ok /= 1) $
ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
@@ -125,7 +127,7 @@ data CExpr
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
| CEBinop CExpr String CExpr -- ^ expr + expr
| CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
- | CECast String CExpr -- ^ (<type)<expr>
+ | CECast String CExpr -- ^ (<type>)<expr>
deriving (Show)
printStructDecl :: StructDecl -> ShowS
@@ -214,23 +216,31 @@ repSTy (STScal st) = case st of
STBool -> "uint8_t"
repSTy t = genStructName t
-genStructName :: STy t -> String
-genStructName = \t -> "ty_" ++ gen t where
- -- all tags start with a letter, so the array mangling is unambiguous.
- gen :: STy t -> String
- gen STNil = "n"
- gen (STPair a b) = 'P' : gen a ++ gen b
- gen (STEither a b) = 'E' : gen a ++ gen b
- gen (STMaybe t) = 'M' : gen t
- gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
- gen (STScal st) = case st of
- STI32 -> "i"
- STI64 -> "j"
- STF32 -> "f"
- STF64 -> "d"
- STBool -> "b"
- gen (STAccum t) = 'C' : gen (fromSMTy t)
- gen (STLEither a b) = 'L' : gen a ++ gen b
+genStructName, genArrBufStructName :: STy t -> String
+(genStructName, genArrBufStructName) =
+ (\t -> "ty_" ++ gen t
+ ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension
+ t -> error $ "genArrBufStructName: not an array type: " ++ show t)
+ where
+ -- all tags start with a letter, so the array mangling is unambiguous.
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
+
+-- The subtrees contain structs used in the bodies of the structs in this node.
+data StructTree = TreeNode [StructDecl] [StructTree]
+ deriving (Show)
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
@@ -238,60 +248,56 @@ genStructName = \t -> "ty_" ++ gen t where
--
-- For accumulation it is important that for struct representations of monoid
-- types, the all-zero-bytes value corresponds to the zero value of that type.
-genStruct :: String -> STy t -> [StructDecl]
-genStruct name topty = case topty of
+buildStructTree :: STy t -> StructTree
+buildStructTree topty = case topty of
STNil ->
- [StructDecl name "" com]
+ TreeNode [StructDecl name "" com] []
STPair a b ->
- [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ [buildStructTree a, buildStructTree b]
STEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
STMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ [buildStructTree t]
STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- -- TODO: put shape in the main struct, not the buffer; it's constant, after all
-- TODO: no buffer if n = 0
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
+ TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") ""
+ ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
+ [buildStructTree t]
STScal _ ->
- []
+ TreeNode [] []
STAccum t ->
- [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
+ [buildStructTree (fromSMTy t)]
where
+ name = genStructName topty
com = ppSTy 0 topty
-- State: already-generated (skippable) struct names
-- Writer: the structs in declaration order
-genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
-genStructs ty = do
- let name = genStructName ty
- seen <- lift $ gets (name `Set.member`)
-
- if seen
- then pure ()
- else do
- -- already mark this struct as generated now, so we don't generate it
- -- twice (unnecessary because no recursive types, but y'know)
- lift $ modify (Set.insert name)
-
- () <- case ty of
- STNil -> pure ()
- STPair a b -> genStructs a >> genStructs b
- STEither a b -> genStructs a >> genStructs b
- STMaybe t -> genStructs t
- STArr _ t -> genStructs t
- STScal _ -> pure ()
- STAccum t -> genStructs (fromSMTy t)
- STLEither a b -> genStructs a >> genStructs b
-
- tell (BList (genStruct name ty))
+genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructTreeW (TreeNode these deps) = do
+ seen <- lift get
+ case filter ((`Set.notMember` seen) . nameOf) these of
+ [] -> pure ()
+ structs -> do
+ lift $ modify (Set.fromList (map nameOf structs) <>)
+ mapM_ genStructTreeW deps
+ tell (BList structs)
+ where
+ nameOf (StructDecl name _ _) = name
genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
+genAllStructs tys =
+ let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys
+ in toList (evalState (execWriterT m) mempty)
data CompState = CompState
{ csStructs :: Set (Some STy)
@@ -340,6 +346,12 @@ emitStruct ty = CompM $ do
modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
return (genStructName ty)
+-- | Also returns the name of the array buffer struct
+emitArrStruct :: STy t -> CompM (String, String)
+emitArrStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+ return (genStructName ty, genArrBufStructName ty)
+
emitTLD :: String -> CompM ()
emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
@@ -463,6 +475,15 @@ serialise topty topval ptr off k =
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -471,19 +492,18 @@ serialise topty topval ptr off k =
serialise t x ptr (off + alignmentSTy t) k
(STArr n t, Array sh vec) -> do
let eltsz = sizeofSTy t
- allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
when debugRefc $
hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
pokeByteOff ptr off bufptr
+ pokeShape ptr (off + 8) n sh
- pokeShape bufptr 0 n sh
- pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63)
+ pokeByteOff @Word64 bufptr 0 (2 ^ 63)
- let off1 = fromSNat n * 8 + 8
- loop i
+ let loop i
| i == shapeSize sh = k
| otherwise =
- serialise t (vec V.! i) bufptr (off1 + i * eltsz) $
+ serialise t (vec V.! i) bufptr (8 + i * eltsz) $
loop (i+1)
loop 0
(STScal sty, x) -> case sty of
@@ -493,15 +513,6 @@ serialise topty topval ptr off k =
STF64 -> pokeByteOff ptr off (x :: Double) >> k
STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
(STAccum{}, _) -> error "Cannot serialise accumulators"
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
-- | Assumes that this is called at the correct alignment.
deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
@@ -518,6 +529,13 @@ deserialise topty ptr off =
if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -525,13 +543,12 @@ deserialise topty ptr off =
else Just <$> deserialise t ptr (off + alignmentSTy t)
STArr n t -> do
bufptr <- peekByteOff @(Ptr ()) ptr off
- sh <- peekShape bufptr 0 n
- refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ sh <- peekShape ptr (off + 8) n
+ refc <- peekByteOff @Word64 bufptr 0
when debugRefc $
hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
- let off1 = 8 * fromSNat n + 8
- eltsz = sizeofSTy t
- arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
+ let eltsz = sizeofSTy t
+ arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
when (refc < 2 ^ 62) $ free bufptr
return arr
STScal sty -> case sty of
@@ -541,13 +558,6 @@ deserialise topty ptr off =
STF64 -> peekByteOff @Double ptr off
STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
STAccum{} -> error "Cannot serialise accumulators"
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
align :: Int -> Int -> Int
align a off = (off + a - 1) `div` a * a
@@ -569,10 +579,14 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr _ _) = (8, 8)
+metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)
metricsSTy (STScal sty) = case sty of
STI32 -> (4, 4)
STI64 -> (8, 8)
@@ -580,10 +594,6 @@ metricsSTy (STScal sty) = case sty of
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -747,15 +757,17 @@ compile' env = \case
return (CELit retvar)
EConstArr _ n t (Array sh vec) -> do
- strname <- emitStruct (STArr n (STScal t))
+ (strname, bufstrname) <- emitArrStruct (STArr n (STScal t))
tldname <- genName' "carraybuf"
-- Give it a refcount of _half_ the size_t max, so that it can be
-- incremented and decremented at will and will "never" reach anything
-- where something happens
- emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++
- "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
- ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
- return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
+ emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++
+ "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++
+ ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname
+ [("buf", CEAddrOf (CELit tldname))
+ ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
EBuild _ n esh efun -> do
shname <- compileAssign "sh" env esh
@@ -770,7 +782,7 @@ compile' env = \case
emit $ SBlock $
pure (SVarDecl False "size_t" linivar (CELit "0"))
<> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
| (ivar, dimidx) <- zip ivars [0::Int ..]]
(pure (SVarDecl True (repSTy (typeOf esh)) idxargname
(shapeTupFromLitVars n ivars))
@@ -799,7 +811,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
@@ -845,7 +857,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
let vecwid = 8 :: Int
ivar <- genName' "i"
@@ -909,6 +921,136 @@ compile' env = \case
EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -934,7 +1076,7 @@ compile' env = \case
when emitChecks $
forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
- (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
arrname ++ ".buf); return false;")
@@ -977,6 +1119,8 @@ compile' env = \case
maybe (return ()) ($ name2) mfun2
return (CELit name)
+ ERecompute _ e -> compile' env e
+
EWith _ t e1 e2 -> do
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
@@ -1000,95 +1144,7 @@ compile' env = \case
rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
- EAccum _ t prj eidx eval eacc -> do
- let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
- -- full zero array with the given zero info (for the type SMTArr n t1).
- initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
- initZeroArray n t1 v vzi = do
- shszname <- genName' "inacshsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
- newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
- emit $ SAsg v (CELit newarrName)
- forM_ (initZeroFromMemset t1) $ \f1 -> do
- ivar <- genName' "i"
- ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
-
- -- If something needs to be done to properly initialise this type to
- -- zero after memory has already been initialised to all-zero bytes,
- -- returns an action that does so.
- -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
- initZeroFromMemset SMTNil = Nothing
- initZeroFromMemset (SMTPair t1 t2) =
- case (initZeroFromMemset t1, initZeroFromMemset t2) of
- (Nothing, Nothing) -> Nothing
- (mf1, mf2) -> Just $ \v vzi -> do
- forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
- forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
- initZeroFromMemset SMTLEither{} = Nothing
- initZeroFromMemset SMTMaybe{} = Nothing
- initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
- initZeroFromMemset SMTScal{} = Nothing
-
- let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroZI :: SMTy a -> String -> String -> CompM ()
- initZeroZI SMTNil _ _ = return ()
- initZeroZI (SMTPair t1 t2) v vzi = do
- initZeroZI t1 (v++".a") (vzi++".a")
- initZeroZI t2 (v++".b") (vzi++".b")
- initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZeroZI (SMTScal sty) v _ = case sty of
- STI32 -> emit $ SAsg v (CELit "0")
- STI64 -> emit $ SAsg v (CELit "0l")
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
-
- let -- Initialise an uninitialised accumulation value, potentially already
- -- with the addend, potentially to zero depending on the nature of the
- -- projection.
- -- 1. If the projection indexes only through dense monoids before
- -- reaching SAPHere, the thing cannot be initialised to zero with
- -- only an AcIdx; it would need to model a zero after the addend,
- -- which is stupid and redundant. In this case, we return Left:
- -- (accumulation value) (AcIdx value) (addend value).
- -- The addend is copied, not consumed. (We can't reliably _always_
- -- consume it, so it's not worth trying to do it sometimes.)
- -- 2. Otherwise, a sparse monoid is found along the way, and we can
- -- initalise the dense prefix of the path to zero by setting the
- -- indexed-through sparse value to a sparse zero. Afterwards, the
- -- main recursion can proceed further. In this case, we return
- -- Right: (accumulation value) (AcIdx value)
- -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
- initZeroChunk :: SMTy a -> SAcPrj p a b
- -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
- (String -> String -> CompM ()) -- zero initialisation of sparse chunk
- initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
- -- reached target before the first sparse constructor
- (t1 , SAPHere ) -> Left $ \v _ addend -> do
- incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
- emit $ SAsg v (CELit addend)
- -- sparse types
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- -- dense types
- (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- f (v++".a") (i++".a")
- initZeroZI t2 (v++".b") (i++".b")
- (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
- initZeroZI t1 (v++".a") (i++".a")
- f (v++".b") (i++".b")
- (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- initZeroArray n t1 v (i++".a.b")
- linidxvar <- genName' "li"
- emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
- f (v++".buf->xs["++linidxvar++"]") (i++".b")
- where
- applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
- applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
-
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
let -- Add a value (s) into an existing accumulation value (d). If a sparse
-- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
@@ -1129,16 +1185,16 @@ compile' env = \case
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
forM_ [0 .. fromSNat n - 1] $ \j -> do
- emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
"!="
- (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (CELit (d ++ ".sh[" ++ show j ++ "]")))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
"dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
d ++ ".buf" ++
- concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
", " ++ s ++ ".buf" ++
- concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
"); " ++
"return false;")
mempty
@@ -1158,67 +1214,55 @@ compile' env = \case
accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
- accumRef (SMTLEither ta tb) prj0 v i addend = do
- let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
- SAPRight prj' -> initZeroChunk tb prj'
- subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
- tagval = case prj0 of SAPLeft{} -> "1"
- SAPRight{} -> "2"
- ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
- SAPRight prj' -> accumRef tb prj' subv i addend
- case chunkres of
- Left densef -> do
- ((), stmtsSet) <- scope $ densef subv i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
- stmtsAdd -- TODO: emit check for consistency of tags?
- Right sparsef -> do
- ((), stmtsInit) <- scope $ sparsef subv i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
- forM_ stmtsAdd emit
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- case initZeroChunk tj prj' of
- Left densef -> do
- ((), stmtsSet1) <- scope $ densef (v++".j") i addend
- ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
- stmtsAdd1
- Right sparsef -> do
- ((), stmtsInit1) <- scope $ sparsef (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i addend
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
"return false;")
mempty
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
nameidx <- compileAssign "acidx" env eidx
nameval <- compileAssign "acval" env eval
@@ -1232,6 +1276,9 @@ compile' env = \case
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
@@ -1245,6 +1292,7 @@ compile' env = \case
return $ CEStruct name []
EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
@@ -1303,13 +1351,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
@@ -1361,21 +1409,21 @@ toLinearIdx SZ _ _ = CELit "0"
toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
toLinearIdx (SS n) arrvar idxvar =
CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
- "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n)))))
+ "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))))
"+" (CELit (idxvar ++ ".b"))
-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
-- fromLinearIdx (SS n) arrvar idxvar = do
-- name <- genName
--- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
-- _
data AllocMethod = Malloc | Calloc
deriving (Show)
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
allocArray marker method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
@@ -1390,9 +1438,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do
(CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
emit $ SVarDecl True strname arrname $ CEStruct strname
[("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
- Calloc -> CECall "calloc_instr" [nbytesExpr])]
- forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
- emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
+ Calloc -> CECall "calloc_instr" [nbytesExpr])
+ ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))]
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
when debugRefc $
emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
@@ -1403,16 +1450,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
compileShapeQuery (SS n) var =
CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
[("a", compileShapeQuery n var)
- ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
+ ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))]
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
-- | Takes a variable name for the array, not the buffer.
compileArrShapeComponents :: SNat n -> String -> [CExpr]
compileArrShapeComponents n var =
- [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
indexTupleComponents :: SNat n -> String -> [CExpr]
indexTupleComponents = \n var -> map CELit (toList (go n var))
@@ -1431,6 +1478,9 @@ shapeTupFromLitVars = \n -> go n . reverse
go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+prodExpr :: [CExpr] -> CExpr
+prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @CompM $ CECall cop [e1]
@@ -1503,7 +1553,7 @@ compileExtremum nameBase opName operator env e = do
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
@@ -1574,7 +1624,7 @@ copyForWriting topty var = case topty of
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
-- let's not do that. Furthermore, no sub-arrays means that the whole thing
-- is flat, and we can just memcpy if necessary.
- SMTArr n t | not (hasArrays (fromSMTy t)) -> do
+ SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
emit $ SVarDeclUninit toptyname name
@@ -1583,7 +1633,7 @@ copyForWriting topty var = case topty of
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
emit $ SVerbatim $
"fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
- concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
");"
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
@@ -1594,8 +1644,7 @@ copyForWriting topty var = case topty of
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
printCExpr 0 databytes ");"])
@@ -1612,8 +1661,7 @@ copyForWriting topty var = case topty of
name <- genName
emit $ SVarDecl False toptyname name
(CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain
@@ -1657,6 +1705,14 @@ zeroRefcountCheck toptyp opname topvar =
go (STEither a b) path = do
(s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
go (STMaybe a) path = do
ss <- go a (path++".j")
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
@@ -1673,14 +1729,6 @@ zeroRefcountCheck toptyp opname topvar =
return (BList [s1, s2, s3])
go STScal{} _ = empty
go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
combine (MaybeT a) (MaybeT b) = MaybeT $ do
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index d708fc0..cc6d5fa 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -11,8 +11,6 @@ module Compile.Exec (
import Control.Monad (when)
import Data.IORef
-import qualified Data.Map.Strict as Map
-import Data.Map.Strict (Map)
import Foreign (Ptr)
import Foreign.Ptr (FunPtr)
import System.Directory (removeDirectoryRecursive)
@@ -30,10 +28,10 @@ debug :: Bool
debug = False
-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
-data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ()))))
+data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ())))
-buildKernel :: String -> [String] -> IO KernelLib
-buildKernel csource funnames = do
+buildKernel :: String -> String -> IO KernelLib
+buildKernel csource funname = do
template <- (++ "/tmp.chad.") <$> getTempDir
path <- mkdtemp template
@@ -44,7 +42,8 @@ buildKernel csource funnames = do
,"-o", outso, "-"
,"-Wall", "-Wextra"
,"-Wno-unused-variable", "-Wno-unused-but-set-variable"
- ,"-Wno-unused-parameter", "-Wno-unused-function"]
+ ,"-Wno-unused-parameter", "-Wno-unused-function"
+ ,"-Wno-alloc-size-larger-than"] -- ideally we'd keep this, but gcc reports false positives
(ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource
-- Print the source before the GCC output.
@@ -69,8 +68,7 @@ buildKernel csource funnames = do
removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now
- ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames]
- ref <- newIORef ptrs
+ ref <- newIORef =<< dlsym dl funname
_ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))
when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
dlclose dl)
@@ -81,10 +79,10 @@ foreign import ccall "dynamic"
-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive
{-# NOINLINE callKernelFun #-}
-callKernelFun :: String -> KernelLib -> Ptr () -> IO ()
-callKernelFun key (KernelLib ref) arg = do
- mp <- readIORef ref
- wrapKernelFun (mp Map.! key) arg
+callKernelFun :: KernelLib -> Ptr () -> IO ()
+callKernelFun (KernelLib ref) arg = do
+ ptr <- readIORef ref
+ wrapKernelFun ptr arg
getTempDir :: IO FilePath
getTempDir =
diff --git a/src/Data.hs b/src/Data.hs
index e86aaa6..e6978c8 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -8,12 +8,13 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module Data (module Data, (:~:)(Refl)) where
+module Data (module Data, (:~:)(Refl), If) where
import Data.Functor.Product
import Data.GADT.Compare
import Data.GADT.Show
import Data.Some
+import Data.Type.Bool (If)
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -184,3 +185,8 @@ instance Applicative Bag where
instance Semigroup (Bag t) where (<>) = BTwo
instance Monoid (Bag t) where mempty = BNone
+
+data SBool b where
+ SF :: SBool False
+ ST :: SBool True
+deriving instance Show (SBool b)
diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs
index 9c10421..2712b08 100644
--- a/src/Data/VarMap.hs
+++ b/src/Data/VarMap.hs
@@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'
subMap subenv =
let bools = let loop :: Subenv env env' -> [Bool]
loop SETop = []
- loop (SEYes sub) = True : loop sub
+ loop (SEYesR sub) = True : loop sub
loop (SENo sub) = False : loop sub
in VS.fromList $ loop subenv
newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools
@@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env
superMap subenv =
let loop :: Subenv env env' -> Int -> [Int]
loop SETop _ = []
- loop (SEYes sub) i = i : loop sub (i+1)
+ loop (SEYesR sub) i = i : loop sub (i+1)
loop (SENo sub) i = loop sub (i+1)
newIndices = VS.fromList $ loop subenv 0
diff --git a/src/Example.hs b/src/Example.hs
index 3623d03..2c51291 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -5,13 +5,18 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
+
+{-# OPTIONS -Wno-unused-imports #-}
module Example where
import Array
import AST
+import AST.Count
import AST.Pretty
+import AST.UnMonoid
import CHAD
import CHAD.Top
+import CHAD.Types
import ForwardAD
import Interpreter
import Language
@@ -24,6 +29,22 @@ import Example.Types
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
+pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+pipeline config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ chad' config knownEnv $
+ simplifyFix $
+ term
+
+-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures
+pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
+pipeline' config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ pprintExpr (pipeline config term)
+
+
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)
@@ -159,8 +180,18 @@ neuralGo =
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
- (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
- _ -> undefined
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
+
+-- The build body uses free variables in a non-linear way, so their primal
+-- values are required in the dual of the build. Thus, compositionally, they
+-- are stored in the tape from each individual lambda invocation. This results
+-- in n copies of y and z, where only one copy would have sufficed.
+exUniformFree :: Ex '[R, I64] R
+exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $
+ let_ #y (#x * 2) $
+ let_ #z (#x * 3) $
+ idx0 $ sum1i $
+ build1 #n $ #i :-> #y * #z + toFloat_ #i
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 5756f96..b353def 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -26,10 +26,10 @@ type family Tan t where
Tan TNil = TNil
Tan (TPair a b) = TPair (Tan a) (Tan b)
Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
- Tan (TLEither a b) = TLEither (Tan a) (Tan b)
type family TanS t where
TanS TI32 = TNil
@@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair a b) = STPair (tanty a) (tanty b)
tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanty (STMaybe t) = STMaybe (tanty t)
tanty (STArr n t) = STArr n (tanty t)
tanty (STScal t) = case t of
@@ -55,7 +56,6 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
-tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
@@ -66,6 +66,9 @@ zeroTan STNil () = ()
zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
zeroTan (STMaybe _) Nothing = Nothing
zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
zeroTan (STArr _ t) x = fmap (zeroTan t) x
@@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
-zeroTan (STLEither _ _) Nothing = Nothing
-zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
-zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
@@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
-tanScalars (STLEither _ _) Nothing = []
-tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
-tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =
unzipDN (STEither a b) d = case d of
Left d1 -> bimap Left Left (unzipDN a d1)
Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
unzipDN (STMaybe t) d = case d of
Nothing -> (Nothing, Nothing)
Just d' -> bimap Just Just (unzipDN t d')
@@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
-unzipDN (STLEither a b) d = case d of
- Nothing -> (Nothing, Nothing)
- Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
- Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of
(Left x', Left y') -> dotprodTan a x' y'
(Right x', Right y') -> dotprodTan b x' y'
_ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
(Nothing, Nothing) -> 0.0
(Just x', Just y') -> dotprodTan t x' y'
@@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
-dotprodTan (STLEither a b) x y = case (x, y) of
- (Nothing, _) -> 0.0 -- 0 * y = 0
- (_, Nothing) -> 0.0 -- x * 0 = 0
- (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
- (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
- _ -> error "dotprodTan: incompatible LEither alternatives"
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
dnConst (STMaybe t) = fmap (dnConst t)
dnConst (STArr _ t) = arrayMap (dnConst t)
dnConst (STScal t) = case t of
@@ -196,7 +197,6 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
-dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e =
case e of
Left x -> \f -> Left (dnOnehots t1 x (f . Left))
Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnOnehots (STMaybe t) m =
case m of
Nothing -> \_ -> Nothing
@@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
-dnOnehots (STLEither t1 t2) e =
- case e of
- Nothing -> \_ -> Nothing
- Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
- Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index ebc70d7..44bdbb2 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -168,6 +168,8 @@ dfwdDN = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EReshape _ n esh e
+ | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)
EConst _ t x -> scalTyCase t
(EPair ext (EConst ext t x) (EConst ext t 0.0))
(EConst ext t x)
@@ -185,16 +187,22 @@ dfwdDN = \case
ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
+ ERecompute _ e -> dfwdDN e
EError _ t s -> EError ext (dn t) s
EWith{} -> err_accum
EAccum{} -> err_accum
+ EDeepZero{} -> err_monoid
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
+
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
+ err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program"
deriv_extremum :: ScalIsNumeric t ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
@@ -218,4 +226,4 @@ dfwdDN = \case
(EFst ext (EVar ext tIxN (IS IZ)))))))
(ESnd ext (EVar ext t2 (IS IZ)))
(zeroScalarConst t))))
- (EMaximum1Inner ext (dfwdDN e))
+ (extremum (dfwdDN e))
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
index 3c76cbe..dcacf5f 100644
--- a/src/ForwardAD/DualNumbers/Types.hs
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -12,10 +12,10 @@ type family DN t where
DN TNil = TNil
DN (TPair a b) = TPair (DN a) (DN b)
DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TLEither a b) = TLEither (DN a) (DN b)
DN (TMaybe t) = TMaybe (DN t)
DN (TArr n t) = TArr n (DN t)
DN (TScal t) = DNS t
- DN (TLEither a b) = TLEither (DN a) (DN b)
type family DNS t where
DNS TF32 = TPair (TScal TF32) (TScal TF32)
@@ -32,6 +32,7 @@ dn :: STy t -> STy (DN t)
dn STNil = STNil
dn (STPair a b) = STPair (dn a) (dn b)
dn (STEither a b) = STEither (dn a) (dn b)
+dn (STLEither a b) = STLEither (dn a) (dn b)
dn (STMaybe t) = STMaybe (dn t)
dn (STArr n t) = STArr n (dn t)
dn (STScal t) = case t of
@@ -41,7 +42,6 @@ dn (STScal t) = case t of
STI64 -> STScal STI64
STBool -> STScal STBool
dn STAccum{} = error "Accum in source program"
-dn (STLEither a b) = STLEither (dn a) (dn b)
dne :: SList STy env -> SList STy (DNE env)
dne SNil = SNil
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d7916d8..9e3d2a6 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,12 +21,16 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict (runStateT, get, put)
+import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
+import Data.Tuple (swap)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
@@ -35,6 +39,7 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
+import AST.Sparse.Types
import Data
import Interpreter.Rep
@@ -141,6 +146,43 @@ interpret'Rec env = \case
sh `ShCons` n = arrayShape arr
numericIsNum t $ return $
arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EReshape _ dim esh e -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh
+ arr <- interpret' env e
+ return $ arrayReshape sh arr
+ EFold1InnerD1 _ _ a b c -> do
+ let t = typeOf b
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ -- TODO: this is very inefficient, even for an interpreter; with mutable
+ -- arrays this can be a lot better with no lists
+ res <- arrayGenerateM sh $ \idx -> do
+ (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ return (y, arrayFromList (ShNil `ShCons` n) stores)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EFold1InnerD2 _ _ ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
+ bog <- interpret' env ebog
+ arrctg <- interpret' env ed
+ let sh `ShCons` n = arrayShape bog
+ when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2"
+ res <- arrayGenerateM sh $ \idx -> do
+ let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
+ loop i !ctg !inpctgs = do
+ let b = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f b ctg
+ loop (i - 1) ctg1 (ctg2 : inpctgs)
+ (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
+ return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
@@ -153,18 +195,22 @@ interpret'Rec env = \case
e1' <- interpret' env e1
e2' <- interpret' env e2
interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
interpret' (V (STAccum t) accum `SCons` env) e2
- EAccum _ t p e1 e2 e3 -> do
+ EAccum _ t p e1 sp e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
+ accumAddSparseD t p accum idx sp val
EZero _ t ezi -> do
zi <- interpret' env ezi
return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
@@ -215,6 +261,19 @@ zeroM typ zi = case typ of
STF32 -> 0.0
STF64 -> 0.0
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
addM :: SMTy t -> Rep t -> Rep t -> Rep t
addM typ a b = case typ of
SMTNil -> ()
@@ -238,7 +297,7 @@ addM typ a b = case typ of
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
-onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
@@ -255,15 +314,6 @@ withAccum t _ initval f = AcM $ do
val <- readAc t accum
return (out, val)
-newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
-newAcZero typ zi = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi
- SMTLEither{} -> newIORef Nothing
- SMTMaybe _ -> newIORef Nothing
- SMTArr _ t -> arrayMapM (newAcZero t) zi
- SMTScal sty -> numericIsNum sty $ newIORef 0
-
newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
newAcDense typ val = case typ of
SMTNil -> return ()
@@ -273,26 +323,10 @@ newAcDense typ val = case typ of
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
-newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a)
-newAcSparse typ prj idx val = case (typ, prj) of
- (_, SAPHere) -> newAcDense typ val
-
- (SMTPair t1 t2, SAPFst prj') ->
- (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx)
- (SMTPair t1 t2, SAPSnd prj') ->
- (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val
-
- (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
- (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
-
- (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
-
- (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx
-
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
-> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = arrayShape ziarr
@@ -308,54 +342,67 @@ readAc typ val = case typ of
SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
-accumAddDense typ ref val = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> do
- accumAddDense t1 (fst ref) (fst val)
- accumAddDense t2 (snd ref) (snd val)
- SMTLEither{} ->
- case val of
- Nothing -> return ()
- Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
- Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- SMTMaybe{} ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- SMTArr _ t1 ->
- forM_ [0 .. arraySize ref - 1] $ \i ->
- accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
- SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
-
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (_, SAPHere) -> accumAddDense typ ref val
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
- (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
- (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
(SMTLEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
- Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
(SMTLEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
- Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
(SMTMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
(SMTArr n t1, SAPArrIdx prj') ->
- let ((arrindex', ziarr), idx') = idx
+ let (arrindex', idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = arrayShape ziarr
+ arrsh = arrayShape ref
linindex = toLinearIndex arrsh arrindex
- in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd
@@ -404,3 +451,11 @@ ixUncons (IxCons idx i) = (idx, i)
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)
+
+mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b)
+mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f'
+ where f' x = do
+ s <- get
+ (s', y) <- lift $ f s x
+ put s'
+ return y
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 070ba4c..1682303 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -22,11 +22,11 @@ type family Rep t where
Rep TNil = ()
Rep (TPair a b) = (Rep a, Rep b)
Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
Rep (TAccum t) = RepAc t
- Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
-- Mutable, represents monoid types t.
type family RepAc t where
@@ -56,6 +56,9 @@ showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y
+showValue _ (STLEither _ _) Nothing = showString "LNil"
+showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
+showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr = showParen (d > 10) $
@@ -70,9 +73,6 @@ showValue d (STScal sty) x = case sty of
STI64 -> showsPrec d x
STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
-showValue _ (STLEither _ _) Nothing = showString "LNil"
-showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
-showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
@@ -86,6 +86,9 @@ rnfRep STNil () = ()
rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
rnfRep (STEither a _) (Left x) = rnfRep a x
rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
rnfRep (STMaybe _) Nothing = ()
rnfRep (STMaybe t) (Just x) = rnfRep t x
rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
@@ -97,9 +100,6 @@ rnfRep (STScal t) x = case t of
STF64 -> rnf x
STBool -> rnf x
rnfRep STAccum{} _ = error "Cannot rnf accumulators"
-rnfRep (STLEither _ _) Nothing = ()
-rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
-rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
instance KnownTy t => NFData (Value t) where
rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/Language.hs b/src/Language.hs
index 9fd5dd3..31b4b87 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -17,6 +17,7 @@ module Language (
import Array
import AST
+import AST.Sparse.Types
import AST.Types
import CHAD.Types
import Data
@@ -129,6 +130,17 @@ maximum1i e = NEMaximum1Inner e
minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
minimum1i e = NEMinimum1Inner e
+reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+reshape = NEReshape
+
+fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3
+
+fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
+ -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
@@ -169,11 +181,17 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
with a (n :-> b) = NEWith (knownMTy @t) a n b
-accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
-accum p a b c = NEAccum knownMTy p a b c
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 8bcb5e5..c9d05c9 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -55,6 +56,16 @@ data NExpr env t where
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+
+ NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
@@ -71,9 +82,12 @@ data NExpr env t where
-> NExpr env a -> NExpr env b
-> NExpr env t
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
-- accumulation effect on monoids
NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
- NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -201,6 +215,10 @@ fromNamedExpr val = \case
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
NEMaximum1Inner e -> EMaximum1Inner ext (go e)
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
+ NEReshape n a b -> EReshape ext n (go a) (go b)
+
+ NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
@@ -215,9 +233,10 @@ fromNamedExpr val = \case
(fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
(fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
(go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
NEWith t a n b -> EWith ext t (go a) (lambda val n b)
- NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
NEError t s -> EError ext t s
diff --git a/src/Simplify.hs b/src/Simplify.hs
index f5eb0a1..b89d7f6 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE QuasiQuotes #-}
@@ -19,13 +21,14 @@ import Control.Monad (ap)
import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
-import Data.Type.Equality (testEquality)
import Debug.Trace
import AST
import AST.Count
import AST.Pretty
+import AST.Sparse.Types
+import AST.UnMonoid (acPrjCompose)
import Data
import Simplify.TH
@@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id)
smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
smReconstruct core = SM (\ctx -> (Any False, ctx core))
-tellActed :: SM tenv tt env t ()
-tellActed = SM (\_ -> (Any True, ()))
+class Monad m => ActedMonad m where
+ tellActed :: m ()
+ hideActed :: m a -> m a
+ liftActed :: (Any, a) -> m a
+
+instance ActedMonad ((,) Any) where
+ tellActed = (Any True, ())
+ hideActed (_, x) = (Any False, x)
+ liftActed = id
+
+instance ActedMonad (SM tenv tt env t) where
+ tellActed = SM (\_ -> tellActed)
+ hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
+ liftActed pair = SM (\_ -> pair)
-- more convenient in practice
-acted :: SM tenv tt env t a -> SM tenv tt env t a
+acted :: ActedMonad m => m a -> m a
acted m = tellActed >> m
within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
-acted' :: (Any, a) -> (Any, a)
-acted' (_, x) = (Any True, x)
-
-liftActed :: (Any, a) -> SM tenv tt env t a
-liftActed pair = SM $ \_ -> pair
-
simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
simplify' expr
| scLogging ?config = do
@@ -167,10 +176,10 @@ simplify'Rec = \case
ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
- EAccum _ t p e1 (ELet _ rhs body) acc ->
+ EAccum _ t p e1 sp (ELet _ rhs body) acc ->
acted $ simplify' $
ELet ext rhs $
- EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc)
+ EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
-- let () = e in () ~> e
ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
@@ -183,10 +192,20 @@ simplify'Rec = \case
ESnd _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
+ EFst _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (EFst ext e1) (EFst ext e2) e3
+ ESnd _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
-- TODO: more array indexing
- EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
- EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
+ EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2
+ EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1
+
+ -- TODO: more array shape
+ EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1
-- TODO: more constant folding
EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
@@ -216,23 +235,40 @@ simplify'Rec = \case
acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
- EAccum _ t p e1 e2 acc -> do
- e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
- e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
- acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ EAccum _ t p e1 sp e2 acc -> do
+ e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
+ simplifyOHT (OneHotTerm SAID t p e1' sp e2')
(acted $ return (ENil ext))
- (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
+ (\sp' (InContext w wrap e) -> do
+ e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
+ return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
+ (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
+ -- The acted management here is a hideous mess.
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
+ return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
(acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
- (\e -> acted $ return e)
- (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
+ (\sp' (InContext _ wrap e) ->
+ case isDense t sp' of
+ Just Refl -> do
+ e' <- hideActed $ within wrap $ simplify' e
+ return (wrap e')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+ (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
+ case isDense (acPrjTy p' t') sp' of
+ Just Refl -> do
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
+ return (wrap $ EOneHot ext t' p' e1''' e2''')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
-- type-specific equations for plus
EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
@@ -278,6 +314,9 @@ simplify'Rec = \case
EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
+ EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
+ EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
+ EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]
EConst _ t v -> pure $ EConst ext t v
EIdx0 _ e -> [simprec| EIdx0 ext *e |]
EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
@@ -291,24 +330,18 @@ simplify'Rec = \case
e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
pure (ECustom ext s t p a' b' c' e1' e2')
+ ERecompute _ e -> [simprec| ERecompute ext *e |]
EWith _ t e1 e2 -> do
e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
pure (EWith ext t e1' e2')
- EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e
- EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b
+ -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |]
+ -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |]
+ EZero _ t e -> [simprec| EZero ext t *e |]
+ EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
-cheapExpr :: Expr x env t -> Bool
-cheapExpr = \case
- EVar{} -> True
- ENil{} -> True
- EConst{} -> True
- EFst _ e -> cheapExpr e
- ESnd _ e -> cheapExpr e
- EUnit _ e -> cheapExpr e
- _ -> False
-
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool
@@ -337,6 +370,9 @@ hasAdds = \case
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
EMaximum1Inner _ e -> hasAdds e
EMinimum1Inner _ e -> hasAdds e
+ EReshape _ _ a b -> hasAdds a || hasAdds b
+ EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
@@ -345,8 +381,10 @@ hasAdds = \case
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ _ -> True
+ ERecompute _ e -> hasAdds e
+ EAccum _ _ _ _ _ _ _ -> True
EZero _ _ e -> hasAdds e
+ EDeepZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
@@ -359,57 +397,167 @@ checkAccumInScope = \case SNil -> False
check STNil = False
check (STPair s t) = check s || check t
check (STEither s t) = check s || check t
+ check (STLEither s t) = check s || check t
check (STMaybe t) = check t
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
- check (STLEither s t) = check s || check t
-data OneHotTerm env p a b where
- OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
+data OneHotTerm dense env a where
+ OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
+deriving instance Show (OneHotTerm dense env a)
+
+data InContext f env (a :: Ty) where
+ InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
+
+simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
+ val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
+ return $ OneHotTerm dense t prj idx sp val'
+
+simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
+simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
+ unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
+ acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
+ return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
+simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
-simplifyOneHotTerm :: OneHotTerm env p a b
- -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
- -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
- -> SM tenv tt env t r
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do
- val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1
- case val1' of
- EZero{} -> kzero
- EOneHot _ t2 prj2 idx2 val2
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do
- tellActed -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k
- _ -> case prj1 of
- SAPHere -> ktriv val1
- _ -> k (OneHotTerm t1 prj1 idx1 val1)
+simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
+ | Just Refl <- isDense (acPrjTy prj1 t1) sp =
+ let idx2' :: Ex env (AcIdx dense p2 c)
+ idx2' = case dense of
+ SAID -> reduceAcIdx t2 prj2 idx2
+ SAIS -> idx2
+ in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
+ acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
+simplifyOHT_concat oht = return oht
+
+-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
+-- -- dense, then the Sparse in the output will also be dense. This property is
+-- -- used when simplifying EOneHot, which cannot represent sparsity.
+simplifyOHT :: ActedMonad m => OneHotTerm dense env a
+ -> m r -- ^ Zero case (onehot is actually zero)
+ -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
+ -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
+ -> m r
+simplifyOHT oht kzero ktriv k = do
+ -- traceM $ "sOHT: input " ++ show oht
+ oht1 <- simplifyOHT_recogniseMonoid oht
+ -- traceM $ "sOHT: recog " ++ show oht1
+ InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
+ -- traceM $ "sOHT: unspa " ++ show oht2
+ oht3 <- simplifyOHT_concat oht2
+ -- traceM $ "sOHT: conca " ++ show oht3
+ -- traceM ""
+ case oht3 of
+ OneHotTerm _ _ _ _ _ EZero{} -> kzero
+ OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
+ _ -> k (InContext w1 wrap1 oht3)
+
+-- Sets the acted flag whenever a non-trivial projection is returned or the
+-- output Sparse is different from the input Sparse.
+unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
+ -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
+ -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
+unsparseOneHotD topsp topval k = case (topsp, topval) of
+ -- eliminate always-Just sparse onehot
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
+
+ -- expand the top levels of a onehot for a sparse type into a onehot for the
+ -- corresponding non-sparse type
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPFst spprj) idx' s1' e'
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPSnd spprj) idx' s1' e'
+ (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPLeft spprj) idx' s1' e'
+ (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPRight spprj) idx' s1' e'
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPJust spprj) idx' s1' e'
+ (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
+ | Dict <- styKnown (typeOf idx) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
+ acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
+
+ -- anything else we don't know how to improve
+ _ -> k WId id SAPHere (ENil ext) topsp topval
+
+{-
+unsparseOneHotS :: ActedMonad m
+ => Sparse a a' -> Ex env a'
+ -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
+unsparseOneHotS topsp topval k = case (topsp, topval) of
+ -- order is relevant to make sure we set the acted flag correctly
+ (SpAbsent, v@ENil{}) -> k SpAbsent v
+ (SpAbsent, v@EZero{}) -> k SpAbsent v
+ (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+
+ -- the unsparsifying
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
+
+ -- recursion
+ -- TODO: coproducts could safely become projections as they do not need
+ -- zeroinfo. But that would only work if the coproduct is at the top, because
+ -- as soon as we hit a product, we need zeroinfo to make it a projection and
+ -- we don't have that.
+ (SpSparse s, e) -> k (SpSparse s) e
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
+ acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
+ acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
+ case s2 of SpAbsent -> pure () ; _ -> tellActed
+ k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
+ case s1 of SpAbsent -> pure () ; _ -> tellActed
+ acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
+ k (SpMaybe s1') (EJust ext e')
+ (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
+ k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
+ _ -> _
+-}
-- | Recognises 'EZero' and 'EOneHot'.
recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
recogniseMonoid _ e@EOneHot{} = return e
-recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
- (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2)
- (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
- (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
(a', b') -> return $ EPair ext a' b'
recogniseMonoid typ@(SMTLEither t1 t2) expr =
case expr of
- ELNil{} -> acted' $ return $ EZero ext typ (ENil ext)
- ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
- ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
_ -> return expr
recogniseMonoid typ@(SMTMaybe t1) expr =
case expr of
- ENothing{} -> acted' $ return $ EZero ext typ (ENil ext)
- EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
_ -> return expr
recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
- acted' $ do
+ acted $ do
e' <- recogniseMonoid t e
return $
ELet ext e' $
@@ -418,59 +566,33 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
(ENil ext))
(EVar ext (fromSMTy t) IZ)
recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
- (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
_ -> return e
recogniseMonoid _ e = return e
-concatOneHots :: SMTy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
+ (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
+ (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
- (SMTPair a _, SAPFst prj1') ->
- concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))
- (SMTPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-
- (SMTLEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (SMTLEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
-
- (SMTMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
-
- (SMTArr _ a, SAPArrIdx prj1') ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-
-zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
where
-- invariant: AcIdx expression is duplicable
- go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
go t SAPHere _ e = makeZeroInfo t e
go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
go SMTLEither{} _ _ _ = ENil ext
go SMTMaybe{} _ _ _ = ENil ext
go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
-
-makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
-makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
- where
- -- invariant: expression argument is duplicable
- go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
- go SMTNil _ = ENil ext
- go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
- go SMTLEither{} _ = ENil ext
- go SMTMaybe{} _ = ENil ext
- go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
- go SMTScal{} _ = ENil ext
diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs
index 2e0076a..03a74de 100644
--- a/src/Simplify/TH.hs
+++ b/src/Simplify/TH.hs
@@ -3,7 +3,7 @@ module Simplify.TH (simprec) where
import Data.Bifunctor (first)
import Data.Char
-import Data.List (foldl1')
+import Data.List (foldl', foldl1')
import Language.Haskell.TH
import Language.Haskell.TH.Quote
import Text.ParserCombinators.ReadP
diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs
index e0dc4b3..5ceb866 100644
--- a/test-framework/Test/Framework.hs
+++ b/test-framework/Test/Framework.hs
@@ -2,11 +2,13 @@
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
module Test.Framework (
TestTree,
testGroup,
- testGroupCollapse,
+ groupSetCollapse,
testProperty,
withResource,
withResource',
@@ -18,43 +20,68 @@ module Test.Framework (
TestName,
) where
-import Control.Monad (forM, when)
-import Control.Monad.Trans.State.Strict
+import Control.Concurrent (setNumCapabilities, forkIO, killThread, forkOn)
+import Control.Concurrent.MVar
+import Control.Concurrent.STM
+import Control.Exception (finally)
+import Control.Monad (forM, when, forM_, replicateM_)
import Control.Monad.IO.Class
+import Data.IORef
import Data.List (isInfixOf, intercalate)
import Data.Maybe (isJust, mapMaybe, fromJust)
+import Data.PQueue.Prio.Min qualified as PQ
import Data.String (fromString)
import Data.Time.Clock
-import System.Environment
+import GHC.Conc (getNumProcessors)
+import System.Console.ANSI qualified as ANSI
+import System.Console.Concurrent (outputConcurrent)
+import System.Console.Regions
+import System.Environment (getArgs, getProgName)
import System.Exit
import System.IO (hFlush, hPutStrLn, stdout, stderr, hIsTerminalDevice)
import Text.Read (readMaybe)
-import qualified Hedgehog as H
-import qualified Hedgehog.Internal.Config as H
-import qualified Hedgehog.Internal.Property as H
-import qualified Hedgehog.Internal.Report as H
-import qualified Hedgehog.Internal.Runner as H
-import qualified Hedgehog.Internal.Seed as H.Seed
+import Hedgehog qualified as H
+import Hedgehog.Internal.Config qualified as H
+import Hedgehog.Internal.Property qualified as H
+import Hedgehog.Internal.Report qualified as H
+import Hedgehog.Internal.Runner qualified as H
+import Hedgehog.Internal.Seed qualified as H.Seed
+type TestName = String
+
data TestTree
- = Group Bool String [TestTree]
- | forall a. Resource (IO a) (a -> IO ()) (a -> TestTree)
+ = Group GroupOpts String [TestTree]
+ | forall a. Resource String (IO a) (a -> IO ()) (a -> TestTree)
+ -- ^ Name is not specified by user, but inherited from the tree below
| HP String H.Property
-type TestName = String
+-- Not exported because a Resource is not supposed to have a name in the first place
+treeName :: TestTree -> String
+treeName (Group _ name _) = name
+treeName (Resource name _ _ _) = name
+treeName (HP name _) = name
+
+data GroupOpts = GroupOpts
+ { goCollapse :: Bool }
+ deriving (Show)
+
+defaultGroupOpts :: GroupOpts
+defaultGroupOpts = GroupOpts False
testGroup :: String -> [TestTree] -> TestTree
-testGroup = Group False
+testGroup = Group defaultGroupOpts
-testGroupCollapse :: String -> [TestTree] -> TestTree
-testGroupCollapse = Group True
+groupSetCollapse :: TestTree -> TestTree
+groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name trees
+groupSetCollapse _ = error "groupSetCollapse: not called on a Group"
--- | The @a -> TestTree@ function must use the @a@ only inside properties: when
--- not actually running properties, it will be passed 'undefined'.
+-- | The @a -> TestTree@ function must use the @a@ only inside properties: the
+-- functoin will be passed 'undefined' when exploring the test tree (without
+-- running properties).
withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree
-withResource = Resource
+withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun
-- | Same caveats as 'withResource'.
withResource' :: IO a -> (a -> TestTree) -> TestTree
@@ -66,14 +93,14 @@ testProperty = HP
filterTree :: Options -> TestTree -> Maybe TestTree
filterTree (Options { optsPattern = pat }) = go []
where
- go path (Group collapse name trees) =
+ go path (Group opts name trees) =
case mapMaybe (go (name:path)) trees of
[] -> Nothing
- trees' -> Just (Group collapse name trees')
- go path (Resource make free fun) =
+ trees' -> Just (Group opts name trees')
+ go path (Resource inhname make free fun) =
case go path (fun undefined) of
Nothing -> Nothing
- Just _ -> Just $ Resource make free (fromJust . go path . fun)
+ Just _ -> Just $ Resource inhname make free (fromJust . go path . fun)
go path hp@(HP name _)
| pat `isInfixOf` renderPath (name:path) = Just hp
| otherwise = Nothing
@@ -84,9 +111,12 @@ computeMaxLen :: TestTree -> Int
computeMaxLen = go 0
where
go :: Int -> TestTree -> Int
- go indent (Group True name trees) = maximum (2*indent + length name : map (go (indent+1)) trees)
- go indent (Group False _ trees) = maximum (0 : map (go (indent+1)) trees)
- go indent (Resource _ _ fun) = go indent (fun undefined)
+ go indent (Group opts name trees)
+ -- If we collapse, the name of the group gets prefixed before the final status message after collapsing.
+ | goCollapse opts = maximum (2*indent + length name : map (go (indent+1)) trees)
+ -- If we don't collapse, the group name does get printed but without any status message, so it doesn't need to get accounted for in maxlen.
+ | otherwise = maximum (0 : map (go (indent+1)) trees)
+ go indent (Resource _ _ _ fun) = go indent (fun undefined)
go indent (HP name _) = 2 * indent + length name
data Stats = Stats
@@ -97,22 +127,21 @@ data Stats = Stats
initStats :: Stats
initStats = Stats 0 0
-newtype M a = M (StateT Stats IO a)
- deriving newtype (Functor, Applicative, Monad, MonadIO)
-
-modifyStats :: (Stats -> Stats) -> M ()
-modifyStats f = M (modify f)
+modifyStats :: (?stats :: IORef Stats) => (Stats -> Stats) -> IO ()
+modifyStats f = atomicModifyIORef' ?stats (\s -> (f s, ()))
data Options = Options
{ optsPattern :: String
, optsHelp :: Bool
, optsHedgehogReplay :: Maybe (H.Skip, H.Seed)
, optsHedgehogShrinks :: Maybe Int
+ , optsParallel :: Bool
+ , optsVerbose :: Bool
}
deriving (Show)
defaultOptions :: Options
-defaultOptions = Options "" False Nothing Nothing
+defaultOptions = Options "" False Nothing Nothing False False
parseOptions :: [String] -> Options -> Either String Options
parseOptions [] opts = pure opts
@@ -134,6 +163,8 @@ parseOptions ("--hedgehog-shrinks":arg:args) opts =
case readMaybe arg of
Just n -> parseOptions args opts { optsHedgehogShrinks = Just n }
Nothing -> Left "Invalid argument to '--hedgehog-shrinks'"
+parseOptions ("--parallel":args) opts = parseOptions args opts { optsParallel = True }
+parseOptions ("--verbose":args) opts = parseOptions args opts { optsVerbose = True }
parseOptions (arg:_) _ = Left $ "Unrecognised argument: '" ++ arg ++ "'"
printUsage :: IO ()
@@ -147,7 +178,12 @@ printUsage = do
," test looks like: '^group1/group2/testname$'."
," --hedgehog-replay '{skip} {seed}'"
," Skip to a particular generated Hedgehog test. Should be used"
- ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."]
+ ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."
+ ," --hedgehog-shrinks NUM"
+ ," Limit the number of shrinking steps."
+ ," --parallel Run tests in parallel."
+ ," --verbose Currently only has an effect with --parallel. Also shows OK"
+ ," and timing for test groups, not only individual tests."]
defaultMain :: TestTree -> IO ()
defaultMain tree = do
@@ -165,58 +201,180 @@ runTests options = \tree' ->
return (ExitFailure 1)
Just tree -> do
isterm <- hIsTerminalDevice stdout
- let M m = let ?maxlen = computeMaxLen tree
- ?istty = isterm
- in go 0 id tree
starttm <- getCurrentTime
- (success, stats) <- runStateT m initStats
+ statsRef <- newIORef initStats
+ success <-
+ let ?stats = statsRef
+ ?options = options
+ ?maxlen = computeMaxLen tree
+ ?istty = isterm
+ in if optsParallel options
+ then do nproc <- getNumProcessors
+ setNumCapabilities nproc
+ displayConsoleRegions $
+ withWorkerPool nproc $ \pool -> do
+ let ?pool = pool
+ successVar <- newEmptyMVar
+ runTreePar Nothing [] [] tree successVar
+ readMVar successVar
+ else isJust <$> runTreeSeq 0 [] tree
+ stats <- readIORef statsRef
endtm <- getCurrentTime
let ?istty = isterm in printStats stats (diffUTCTime endtm starttm)
- return (if isJust success then ExitSuccess else ExitFailure 1)
+ return (if success then ExitSuccess else ExitFailure 1)
+
+-- | Returns when all jobs in this tree have been scheduled. When all jobs are
+-- done, the outvar is filled with whether all tests in this tree were
+-- successful.
+-- The mparregion is the parent region to take over, if any. Having a parent
+-- region to take over implies that we are currently executing in a worker and
+-- can hence run blocking user code directly in the current thread.
+runTreePar :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int])
+ => Maybe ConsoleRegion -> [Int] -> [String] -> TestTree -> MVar Bool -> IO ()
+-- TODO: handle collapse somehow?
+runTreePar mparregion revidxlist revpath (Group _ name trees) outvar = do
+ let path = intercalate "/" (reverse (name : revpath))
+ -- outputConcurrent $ "! " ++ path ++ ": Started\n"
+
+ -- If we have exactly one child and we're currently running in a worker, we can continue doing so
+ mparregion2 <- case trees of
+ [_] -> return mparregion
+ _ -> do -- If not, we have to close the region (and implicitly relinquish the worker job)
+ forM_ mparregion closeConsoleRegion
+ return Nothing
+
+ starttm <- getCurrentTime
+ suboutvars <- forM (zip trees [0..]) $ \(tree, idx) -> do
+ suboutvar <- newEmptyMVar
+ runTreePar mparregion2 (idx : revidxlist) (name : revpath) tree suboutvar
+ return suboutvar
+
+ -- outputConcurrent $ "! " ++ path ++ ": Scheduled all\n"
+
+ -- If we took over the parent region then this readMVar will resolve
+ -- immediately and the forkIO would be unnecessary. Meh.
+ _ <- forkIO $ do
+ success <- and <$> traverse readMVar suboutvars
+ endtm <- getCurrentTime
+ -- outputConcurrent $ "! " ++ path ++ ": Done\n"
+ if success && optsVerbose ?options
+ then let text = path ++ ": " ++ ansiGreen ++ "OK" ++ ansiReset ++ " " ++
+ prettyDuration False (realToFrac (diffUTCTime endtm starttm))
+ in outputConcurrent (text ++ "\n")
+ else return ()
+ putMVar outvar success
+
+ return ()
+
+runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runResource topmparregion 1 toptree topoutvar
where
- -- If all tests are successful, returns the number of output lines produced
- go :: (?maxlen :: Int, ?istty :: Bool) => Int -> (String -> String) -> TestTree -> M (Maybe Int)
- go indent path (Group collapse name trees) = do
- liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout
- starttm <- liftIO getCurrentTime
- mlns <- fmap (fmap sum . sequence) . forM trees $
- go (indent + 1) (path . (name++) . ('/':))
- endtm <- liftIO getCurrentTime
- case mlns of
- Just lns | collapse, ?istty -> do
- let thislen = 2*indent + length name
- liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++
- replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++
- "\x1B[32mOK\x1B[0m" ++
- prettyDuration False (realToFrac (diffUTCTime endtm starttm))
- return (Just 1)
- _ -> return mlns
- go indent path (Resource make cleanup fun) = do
- value <- liftIO make
- success <- go indent path (fun value)
- liftIO $ cleanup value
- return success
- go indent path (HP name (H.Property config test)) = do
+ runResource
+ :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int])
+ => Maybe ConsoleRegion -> Int -> TestTree -> MVar Bool -> IO ()
+ runResource mparregion depth (Resource inhname make cleanup fun) outvar = do
+ let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")"
+ path = intercalate "/" (reverse (pathitem : revpath))
+ idxlist = reverse revidxlist
+ -- outputConcurrent $ "! " ++ path ++ ": R Submitting\n"
+ submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do
+ setConsoleRegion makeRegion ('|' : path ++ " [R] making...")
+
+ -- outputConcurrent $ "! " ++ path ++ ": R Making\n"
+ value <- make -- TODO: catch exceptions
+ -- outputConcurrent $ "! " ++ path ++ ": R Made\n"
+
+ -- outputConcurrent $ "! " ++ path ++ ": R Running subtree\n"
+ suboutvar <- newEmptyMVar
+ runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion
+ -- outputConcurrent $ "! " ++ path ++ ": R Scheduled subtree\n"
+
+ _ <- forkIO $ do
+ success <- readMVar suboutvar
+ -- outputConcurrent $ "! " ++ path ++ ": R Subtree done, scheduling cleanup\n"
+ poolSubmit ?pool idxlist (Just outvar) $ do
+ cleanupRegion <- openConsoleRegion Linear
+ setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...")
+ -- outputConcurrent $ "! " ++ path ++ ": R Cleaning up\n"
+ cleanup value -- TODO: catch exceptions
+ -- outputConcurrent $ "! " ++ path ++ ": R Cleanup done\n"
+ closeConsoleRegion cleanupRegion
+ return success
+ return ()
+ runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar
+
+runTreePar mparregion revidxlist revpath (HP name prop) outvar = do
+ let path = intercalate "/" (reverse (name : revpath))
+ idxlist = reverse revidxlist
+
+ submitOrRunIn mparregion idxlist (Just outvar) $ \region -> do
+ -- outputConcurrent $ "! " ++ path ++ ": Running"
+ let prefix = path ++ " [T]"
+ setConsoleRegion region ('|' : prefix)
+ let progressHandler report = do
+ str <- H.renderProgress H.EnableColor (Just (fromString "")) report
+ setConsoleRegion region ('|' : prefix ++ ": " ++ replace '\n' " " str)
+ (ok, rendered) <- runHP progressHandler revpath name prop
+ finishConsoleRegion region (path ++ ": " ++ rendered)
+ return ok
+
+submitOrRunIn :: (?pool :: WorkerPool [Int])
+ => Maybe ConsoleRegion -> [Int] -> Maybe (MVar a) -> (ConsoleRegion -> IO a) -> IO ()
+submitOrRunIn Nothing idxlist outvar fun =
+ poolSubmit ?pool idxlist outvar (openConsoleRegion Linear >>= fun)
+submitOrRunIn (Just reg) _idxlist outvar fun = do
+ result <- fun reg
+ forM_ outvar $ \mvar -> putMVar mvar result
+
+-- | If all tests are successful, returns the number of output lines produced
+runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool)
+ => Int -> [String] -> TestTree -> IO (Maybe Int)
+runTreeSeq indent revpath (Group opts name trees) = do
+ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout
+ starttm <- getCurrentTime
+ mlns <- fmap (fmap sum . sequence) . forM trees $
+ runTreeSeq (indent + 1) (name : revpath)
+ endtm <- getCurrentTime
+ case mlns of
+ Just lns | goCollapse opts, ?istty -> do
let thislen = 2*indent + length name
- let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' '
- when ?istty $ liftIO $ putStr outputPrefix >> hFlush stdout
+ putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++
+ ANSI.setCursorColumnCode 0 ++
+ replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++
+ ansiGreen ++ "OK" ++ ansiReset ++
+ prettyDuration False (realToFrac (diffUTCTime endtm starttm))
+ return (Just 1)
+ _ -> return ((+1) <$> mlns)
+runTreeSeq indent path (Resource _ make cleanup fun) = do
+ value <- make
+ success <- runTreeSeq indent path (fun value)
+ cleanup value
+ return success
+runTreeSeq indent path (HP name prop) = do
+ let thislen = 2*indent + length name
+ let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' '
+ when ?istty $ putStr prefix >> hFlush stdout
+ (ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop
+ putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout
+ return (if ok then Just 1 else Nothing)
- let (config', seedfun) = applyHedgehogOptions options config
- seed <- seedfun
+runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int)
+ => (H.Report H.Progress -> IO ())
+ -> [String]
+ -> String -> H.Property -> IO (Bool, String)
+runHP progressPrinter revpath name (H.Property config test) = do
+ let (config', seedfun) = applyHedgehogOptions ?options config
+ seed <- seedfun
- starttm <- liftIO getCurrentTime
- report <- liftIO $ H.checkReport config' 0 seed test (outputProgress (?maxlen + 2))
- endtm <- liftIO getCurrentTime
+ starttm <- getCurrentTime
+ report <- H.checkReport config' 0 seed test progressPrinter
+ endtm <- getCurrentTime
- liftIO $ do
- when (not ?istty) $ putStr outputPrefix
- printResult report (path name) (diffUTCTime endtm starttm)
- hFlush stdout
+ rendered <- renderResult report (intercalate "/" (reverse (name : revpath))) (diffUTCTime endtm starttm)
- let ok = H.reportStatus report == H.OK
- modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats
- , statsTotal = 1 + statsTotal stats }
- return (if ok then Just 1 else Nothing)
+ let ok = H.reportStatus report == H.OK
+ modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats
+ , statsTotal = 1 + statsTotal stats }
+ return (ok, rendered)
applyHedgehogOptions :: MonadIO m => Options -> H.PropertyConfig -> (H.PropertyConfig, m H.Seed)
applyHedgehogOptions opts config0 =
@@ -232,32 +390,65 @@ outputProgress :: (?istty :: Bool) => Int -> H.Report H.Progress -> IO ()
outputProgress indent report
| ?istty = do
str <- H.renderProgress H.EnableColor (Just (fromString "")) report
- putStr (replace '\n' " " str ++ "\x1B[" ++ show (indent+1) ++ "G")
+ putStr (replace '\n' " " str ++ ANSI.setCursorColumnCode indent)
hFlush stdout
| otherwise = return ()
-printResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO ()
-printResult report path timeTaken = do
+renderResult :: H.Report H.Result -> String -> NominalDiffTime -> IO String
+renderResult report path timeTaken = do
str <- H.renderResult H.EnableColor (Just (fromString "")) report
case H.reportStatus report of
- H.OK -> putStrLn (ansi "\x1B[K" ++ str ++ prettyDuration False (realToFrac timeTaken))
+ H.OK -> return (str ++ prettyDuration False (realToFrac timeTaken))
H.Failed failure -> do
let H.Report { H.reportTests = count, H.reportDiscards = discards } = report
replayInfo = H.skipCompress (H.SkipToShrink count discards (H.failureShrinkPath failure)) ++
" " ++ show (H.reportSeed report)
suffix = "\n Flags to reproduce: `-p '" ++ path ++ "' --hedgehog-replay '" ++ replayInfo ++ "'`"
- putStrLn (ansi "\x1B[K" ++ str ++ suffix)
- _ -> putStrLn (ansi "\x1B[K" ++ str)
+ return (str ++ suffix)
+ _ -> return str
printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO ()
printStats stats timeTaken
| statsOK stats == statsTotal stats = do
- putStrLn $ ansi "\x1B[32m" ++ "All " ++ show (statsTotal stats) ++
- " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m"
+ putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++
+ " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset
| otherwise =
let nfailed = statsTotal stats - statsOK stats
- in putStrLn $ ansi "\x1B[31m" ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++
- " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m"
+ in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++
+ " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset
+
+
+newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob)))
+data PoolJob = forall a. PoolJob (IO a) (Maybe (MVar a))
+data Terminate a = Terminate | Value a
+ deriving (Eq, Ord) -- Terminate sorts before Value
+
+withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a
+withWorkerPool numWorkers k = do
+ chan <- newTVarIO PQ.empty
+ threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan))
+ k (WorkerPool chan) `finally` do
+ replicateM_ numWorkers (atomically $ writeTVar chan PQ.empty)
+ forM_ threads killThread
+ where
+ worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO ()
+ worker idx chan = do
+ job <- atomically $
+ readTVar chan >>= \case
+ PQ.Empty -> retry
+ (_, j) PQ.:< q -> writeTVar chan q >> return j
+ case job of
+ Value (PoolJob action mmvar) -> do
+ -- outputConcurrent $ "[" ++ show idx ++ "] got job\n"
+ result <- action
+ forM_ mmvar $ \mvar -> putMVar mvar result
+ worker idx chan
+ Terminate -> return ()
+
+poolSubmit :: Ord k => WorkerPool k -> k -> Maybe (MVar a) -> IO a -> IO ()
+poolSubmit (WorkerPool chan) key mmvar action =
+ atomically $ modifyTVar chan $ PQ.insert key (Value (PoolJob action mmvar))
+
prettyDuration :: Bool -> Double -> String
prettyDuration False x | x < 0.5 = ""
@@ -273,3 +464,23 @@ replace x ys = concatMap (\y -> if y == x then ys else [y])
ansi :: (?istty :: Bool) => String -> String
ansi | ?istty = id
| otherwise = const ""
+
+ansiRed, ansiGreen, ansiReset :: (?istty :: Bool) => String
+ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red])
+ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green])
+ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset])
+
+-- getTermIsDark :: IO Bool
+-- getTermIsDark = do
+-- mclr <- ANSI.getLayerColor ANSI.Background
+-- case mclr of
+-- Nothing -> return True
+-- Just (RGB r g b) ->
+-- let cvt n = fromIntegral n / fromIntegral (maxBound `asTypeOf` n)
+-- in return $ (cvt r + cvt g + cvt b) / 3 < (0.5 :: Double)
+
+-- ansiRegionBg :: (?istty :: Bool, ?termisdark :: Bool) => String
+-- ansiRegionBg
+-- | not ?istty = ""
+-- | ?termisdark = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.0 0.05 0.1)]
+-- | otherwise = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.95 0.95 1.0)]
diff --git a/test/Main.hs b/test/Main.hs
index f5e4a3c..4bc9082 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -25,6 +25,7 @@ import Test.Framework
import Array
import AST hiding ((.>))
+import AST.Count (pruneExpr)
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
@@ -63,18 +64,18 @@ simplifyIters iters env | Dict <- envKnown env =
SimplIters n -> simplifyN n
SimplFix -> simplifyFix
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
-gradientByCHAD simplIters env term input =
- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False env input dterm
- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
+-- gradientByCHAD simplIters env term input =
+-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
+-- (out, grad) = interpretOpen False env input dterm
+-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
-gradientByCHAD' simplIters env term input =
- second (second (toTanE env input)) $
- gradientByCHAD simplIters env term input
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
+-- gradientByCHAD' simplIters env term input =
+-- second (second (toTanE env input)) $
+-- gradientByCHAD simplIters env term input
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
@@ -85,19 +86,19 @@ extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
extendDN (STEither a _) (Left x) = Left <$> extendDN a x
extendDN (STEither _ b) (Right y) = Right <$> extendDN b y
+extendDN (STLEither _ _) Nothing = pure Nothing
+extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
+extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
extendDN (STArr _ t) arr = traverse (extendDN t) arr
extendDN (STScal sty) x = case sty of
- STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
- STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
+ STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d)
+ STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d)
STI32 -> pure x
STI64 -> pure x
STBool -> pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
-extendDN (STLEither _ _) Nothing = pure Nothing
-extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
-extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
@@ -116,6 +117,10 @@ closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h
closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STLEither _ _) Nothing Nothing = True
+closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
+closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
+closeIshT' _ STLEither{} _ _ = False
closeIshT' _ (STMaybe _) Nothing Nothing = True
closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
closeIshT' _ STMaybe{} _ _ = False
@@ -128,10 +133,6 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
closeIshT' h (STScal STF64) x y = closeIsh' h x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
-closeIshT' _ (STLEither _ _) Nothing Nothing = True
-closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
-closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
-closeIshT' _ STLEither{} _ _ = False
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT = closeIshT' 1e-5
@@ -233,19 +234,19 @@ genValue topty tpl = case topty of
STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
,liftV Right <$> genValue b (emptyTpl b)]
+ STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
+ ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
+ ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
STMaybe t -> Gen.choice [return (Value Nothing)
,liftV Just <$> genValue t (emptyTpl t)]
STArr n t -> genShape n tpl >>= lift . genArray t
STScal sty -> case sty of
- STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10)
+ STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10)
STI32 -> genInt
STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
- STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
- ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
- ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
where
genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
genInt = do
@@ -302,10 +303,12 @@ adTestGen name expr envGenerator =
exprS = simplifyFix expr
in withCompiled env expr $ \primalfun ->
withCompiled env (simplifyFix expr) $ \primalSfun ->
- testGroupCollapse name
+ groupSetCollapse $ testGroup name
[adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
,adTestGenFwd env envGenerator exprS
- ,adTestGenChad env envGenerator expr exprS primalSfun]
+ ,testGroup "chad"
+ [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun
+ ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]]
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
-> Ex env R -> Ex env R
@@ -336,24 +339,31 @@ adTestGenFwd env envGenerator exprS =
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
-adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
+adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env)
-> Ex env R -> Ex env R
-> (SList Value env -> IO Double)
-> TestTree
-adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr
dtermChadS = simplifyFix dtermChad0
- dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+ dtermChadSUS = simplifyFix $ unMonoid dtermChadS
+ dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
dtermSChadS = simplifyFix dtermSChad0
+ dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
+ dtermSChadSUSP = pruneExpr env dtermSChadSUS
in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
- testProperty "chad" $ property $ do
+ withCompiled env dtermSChadSUSP $ \dcompSChadSUSP ->
+ testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
- -- pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
+ -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
+ let dtermChad20 = simplifyN 20 dtermChad0
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20))
+ diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20)))
+ let dtermSChad20 = simplifyN 20 dtermSChad0
+ diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20))
+ diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20)))
input <- forAllWith (showEnv env) envGenerator
outPrimal <- evalIO $ primalSfun input
@@ -363,35 +373,47 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
- tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
- tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
- tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
+ (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
+ (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
+ (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+ tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP
- (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
- let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS
+ (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input)
+ let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP
-- annotate (showEnv (d2e env) gradChad0)
-- annotate (showEnv (d2e env) gradChadS)
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff outCompSChadS closeIsh outPrimal
+ annotate (ppExpr env dtermSChadSUSP)
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outChadSUS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outSChadSUS closeIsh outPrimal
+ diff outSChadSUSP closeIsh outPrimal
+ diff outCompSChadSUSP closeIsh outPrimal
let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
- diff tansChad closeIshE' tansFwd
- diff tansChadS closeIshE' tansFwd
- diff tansSChad closeIshE' tansFwd
- diff tansSChadS closeIshE' tansFwd
- diff tansCompSChadS closeIshE' tansFwd
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansChadSUS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansSChadSUS closeIshE' tansFwd
+ diff tansSChadSUSP closeIshE' tansFwd
+ diff tansCompSChadSUSP closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
@@ -409,7 +431,7 @@ gen_gmm = do
vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
- vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vgamma <- Gen.realFloat (Range.constant (-10) 10)
vm <- Gen.integral (Range.linear 0 5)
let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
k2 = 0.5 * vgamma * vgamma
@@ -435,11 +457,30 @@ gen_neural = do
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+term_build0 :: Ex '[TArr N0 R] R
+term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx
+
term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+term_build1_idx :: Ex '[TVec R] R
+term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $
+ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+
+term_idx_coprod :: Ex '[TVec (TEither R R)] R
+term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ case_ (#x ! pair nil #i)
+ (#a :-> #a * 2)
+ (#b :-> #b * 3)
+
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
@@ -502,25 +543,32 @@ tests_Compile = testGroup "Compile"
,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
with @(TPair R R) (pair 0.0 0.0) $ #ac :->
- let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
nil
,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
- with @(TMaybe (TPair R R)) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
+ with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $
nil
,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
let_ #len (snd_ (shape #x)) $
with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
- let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac)
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
+
+ ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $
+ fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a
+
+ ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $
+ let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $
+ fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d
]
tests_AD :: TestTree
@@ -556,9 +604,7 @@ tests_AD = testGroup "AD"
,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
- ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
- idx0 $
- build SZ (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build0" term_build0
,adTest "build1-sum" term_build1_sum
@@ -566,6 +612,27 @@ tests_AD = testGroup "AD"
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build1-idx" term_build1_idx
+
+ ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#x ! pair nil #i) $
+ 3 * fst_ #p + 2 * snd_ #p
+
+ ,adTest "idx-coprod" $ term_idx_coprod
+
+ ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $
+ let_ #n (snd_ (shape #arr)) $
+ let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $
+ if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x))
+ (pair (inr (3 * #x)) (exp #x)))) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#b ! pair nil #i) $
+ case_ (fst_ #p)
+ (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p)
+ (#b :-> #b * 4)
+
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ maximum1i #x
@@ -607,6 +674,35 @@ tests_AD = testGroup "AD"
,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
+
+ ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree
+
+ ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #n (snd_ #sh * snd_ (fst_ #sh)) $
+ idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a
+
+ ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #innern (snd_ #sh) $
+ let_ #n (#innern * snd_ (fst_ #sh)) $
+ let_ #flata (reshape (SS SZ) (pair nil #n) #a) $
+ -- ensure the input array to EReshape is shared
+ idx0 $ sum1i $
+ build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern))
+
+ ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a
+
+ ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a
+
+ ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ let_ #v 2 $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a
+
+ ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a
]
main :: IO ()