diff options
40 files changed, 4117 insertions, 1449 deletions
diff --git a/README.md b/README.md new file mode 100644 index 0000000..585e3b0 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# chad-fast + +This is work-in-progress research software. If you are interested in techniques +or ideas in this repository, please do [contact me](https://tomsmeding.com), +I'm always happy to chat. diff --git a/bench/Main.hs b/bench/Main.hs index 358ba31..ec9264b 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,11 +1,13 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS -Wno-orphans #-} module Main where @@ -17,6 +19,7 @@ import Data.Kind (Constraint) import GHC.Exts (withDict) import AST +import AST.Count import AST.UnMonoid import Array import qualified CHAD (defaultConfig) @@ -34,8 +37,12 @@ import Simplify gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) gradCHAD config term = compile knownEnv $ - simplifyFix $ unMonoid $ simplifyFix $ - ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ + ELet ext (EConst ext STF64 1.0) $ + chad' config knownEnv $ + simplifyFix term type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where @@ -93,18 +100,19 @@ makeGMMInputs = accumConfig :: CHADConfig accumConfig = chcSetAccum CHAD.defaultConfig -main :: IO () -main = defaultMain - [env (return makeNeuralInputs) $ \inputs -> bgroup "neural" - [env (gradCHAD CHAD.defaultConfig neural) $ \fun -> - bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig neural) $ \fun -> - bench "accum" (nfAppIO fun inputs) - ] - ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm" - [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun -> +bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env)))) + => String -> Ex env R -> SList Value env -> Benchmark +bgroupDefaultAccum name term inputs = + bgroup name + [env (gradCHAD CHAD.defaultConfig term) $ \fun -> bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun -> + ,env (gradCHAD accumConfig term) $ \fun -> bench "accum" (nfAppIO fun inputs) ] + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural + ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False) + ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil) ] diff --git a/chad-fast.cabal b/chad-fast.cabal index b0ed639..df0409d 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -10,6 +10,9 @@ build-type: Simple library exposed-modules: + -- default ghci module on top + Example + Analysis.Identity Array AST @@ -18,6 +21,8 @@ library AST.Count AST.Env AST.Pretty + AST.Sparse + AST.Sparse.Types AST.SplitLets AST.Types AST.UnMonoid @@ -33,7 +38,6 @@ library Compile.Exec Data Data.VarMap - Example Example.GMM Example.Types ForwardAD @@ -81,7 +85,11 @@ library test-framework exposed-modules: Test.Framework build-depends: base, + ansi-terminal, + concurrent-output, hedgehog, + pqueue, + stm, time, transformers hs-source-dirs: test-framework @@ -20,10 +20,13 @@ module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const +import Data.Functor.Identity +import Data.Int (Int64) import Data.Kind (Type) import Array import AST.Accum +import AST.Sparse.Types import AST.Types import AST.Weaken import CHAD.Types @@ -55,20 +58,37 @@ data Expr x env t where ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (t1 : t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -83,6 +103,9 @@ data Expr x env t where -- be backpropagated to; 'a' is the inactive part. The dual field of -- ECustom does not allow a derivative to be generated for 'a', and hence -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x [b, a] t -- ^ regular operation -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass @@ -90,14 +113,28 @@ data Expr x env t where -> Expr x env a -> Expr x env b -> Expr x env t + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -200,6 +237,10 @@ typeOf = \case EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -209,11 +250,13 @@ typeOf = \case EOp _ op _ -> opt2 op ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t EPlus _ t _ _ -> fromSMTy t EOneHot _ t _ _ _ -> fromSMTy t @@ -245,6 +288,9 @@ extOf = \case EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -252,52 +298,63 @@ extOf = \case EShape x _ -> x EOp x _ _ -> x ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = \case - EVar x t i -> EVar (f x) t i - ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) - EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) - EFst x e -> EFst (f x) (mapExt f e) - ESnd x e -> ESnd (f x) (mapExt f e) - ENil x -> ENil (f x) - EInl x t e -> EInl (f x) t (mapExt f e) - EInr x t e -> EInr (f x) t (mapExt f e) - ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) - ENothing x t -> ENothing (f x) t - EJust x e -> EJust (f x) (mapExt f e) - EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) - ELNil x t1 t2 -> ELNil (f x) t1 t2 - ELInl x t e -> ELInl (f x) t (mapExt f e) - ELInr x t e -> ELInr (f x) t (mapExt f e) - ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c) - EConstArr x n t a -> EConstArr (f x) n t a - EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) - EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) - ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) - EUnit x e -> EUnit (f x) (mapExt f e) - EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) - EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) - EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) - EConst x t v -> EConst (f x) t v - EIdx0 x e -> EIdx0 (f x) (mapExt f e) - EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) - EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) - EShape x e -> EShape (f x) (mapExt f e) - EOp x op e -> EOp (f x) op (mapExt f e) - ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) - EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) - EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t e -> EZero (f x) t (mapExt f e) - EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) - EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) - EError x t s -> EError (f x) t s +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t substInline repl = @@ -342,6 +399,9 @@ subst' f w = \case EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -349,9 +409,11 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s @@ -376,11 +438,11 @@ class KnownTy t where knownTy :: STy t instance KnownTy TNil where knownTy = STNil instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy class KnownMTy t where knownMTy :: SMTy t instance KnownMTy TNil where knownMTy = SMTNil @@ -398,11 +460,11 @@ styKnown :: STy t -> Dict (KnownTy t) styKnown STNil = Dict styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- smtyKnown t = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict smtyKnown :: SMTy t -> Dict (KnownMTy t) smtyKnown SMTNil = Dict @@ -423,6 +485,16 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) @@ -447,27 +519,30 @@ eidxEq (SS n) a b (eidxEq n (EFst ext (EVar ext ty (IS IZ))) (EFst ext (EVar ext ty IZ))) -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr n t <- typeOf arr + , Dict <- styKnown t + = ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr n t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELet ext arr1 $ + ELet ext (weakenExpr WSink arr2) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ + weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) ezip arr1 arr2 = @@ -489,11 +564,91 @@ eshapeEmpty (SS n) e = (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) -ezeroD2 :: STy t -> Ex env (D2 t) -ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi -- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil -- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea -- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) -- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index e84034b..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,13 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where import AST.Types -import CHAD.Types import Data @@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t @@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 3d99afe..463586a 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -16,6 +16,7 @@ module AST.Bindings where import AST +import AST.Env import Data import Lemmas @@ -27,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -41,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) @@ -62,3 +72,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t letBinds BTop = id letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId + where + go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) + go _ _ SETop = (BTop, WId) + go (ty `SCons` env) w (SEYesR sub) = + let (bs, w') = go env (WPop w) sub + in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') + go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index feaaa1e..296c021 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} @@ -10,17 +11,31 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} module AST.Count where -import Data.Functor.Const +import Data.Functor.Product +import Data.Some +import Data.Type.Equality import GHC.Generics (Generic, Generically(..)) +import Array import AST import AST.Env import Data +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + data Count = Zero | One | Many deriving (Show, Eq, Ord) @@ -30,6 +45,10 @@ instance Semigroup Count where _ <> _ = Many instance Monoid Count where mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many data Occ = Occ { _occLexical :: Count , _occRuntime :: Count } @@ -40,120 +59,818 @@ instance Show Occ where showsPrec d (Occ l r) = showParen (d > 10) $ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone + +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) + | STArr n t <- typeOf ex -> + elet ex $ + EBuild ext n (EShape ext (evar IZ)) $ + projectSmallerSubstruc s1 s2 + (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + withSome (scaleMany (Some env2'')) $ \env2' -> + occEnvPop' env2' $ \env2 s2 -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + withSome (scaleMany (Some env2'')) $ \env2' -> + occEnvPop' env2' $ \env2 s2 -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1''' mka -> + withSome (scaleMany (Some env1''')) $ \env1'' -> + occEnvPop' env1'' $ \env1' s2 -> + occEnvPop' env1' $ \env1 s1 -> + let s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush (OccPush env' () sElt) () sElt)) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') -instance Monoid (OccEnv env) where - mempty = OccEnd + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - ELNil _ _ _ -> mempty - ELInl _ _ e -> re e - ELInr _ _ e -> re e - ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e - EZero _ _ e -> re e - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w + s = simplifySubstruc (typeOf topexpr) initialS - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) + _ -> k (SEYesR sub) unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t unsafeWeakenWithSubenv = \sub -> @@ -162,7 +879,7 @@ unsafeWeakenWithSubenv = \sub -> Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") where sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ + sinkViaSubenv IZ (SEYesR _) = Just IZ sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..85faba3 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,95 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where +import Data.Type.Equality + +import AST.Sparse import AST.Weaken +import CHAD.Types import Data -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} -subList :: SList f env -> Subenv env env' -> SList f env' +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index e09f3ae..68fc629 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where @@ -25,6 +26,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -70,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -145,12 +148,20 @@ ppExpr' d val expr = case expr of EMaybe _ a b e -> do let STMaybe t = typeOf e - a' <- ppExpr' 11 val a + e' <- ppExpr' 0 val e + a' <- ppExpr' 0 val a name <- genNameIfUsedIn t IZ b b' <- ppExpr' 0 (Const name `SCons` val) b - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ - ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] + return $ ppParen (d > 0) $ + align $ + group (flatAlt + (annotate AKey (ppString "case") <> ppX expr <+> e' + <> hardline <> annotate AKey (ppString "of")) + (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) + <> hardline + <> indent 2 + (ppString "Nothing" <+> ppString "->" <+> a' + <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') ELNil _ _ _ -> return (ppString "LNil") @@ -199,8 +210,7 @@ ppExpr' d val expr = case expr of a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] @@ -225,6 +235,34 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + + EFold1InnerD1 _ cm a b c -> do + name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a + name2 <- genNameIfUsedIn (typeOf b) IZ a + a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -280,6 +318,10 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 @@ -292,18 +334,24 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] EZero _ t e1 -> do e1' <- ppExpr' 11 val e1 return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b @@ -356,6 +404,20 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -388,6 +450,7 @@ ppSTy' :: Int -> STy t -> Doc q ppSTy' _ STNil = ppString "1" ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t ppSTy' d (STArr n t) = ppParen (d > 10) $ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t @@ -398,7 +461,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF64 -> "f64" STBool -> "bool" ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t -ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b ppSMTy :: Int -> SMTy t -> String ppSMTy d ty = render $ ppSMTy' d ty diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..2a29799 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,287 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE RankNTypes #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where + +import Data.Type.Equality + +import AST +import AST.Sparse.Types +import Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This simplifies the sparsePlusS code. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 159934d..d276e44 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -35,6 +35,13 @@ splitLets' = \sub -> \case EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -54,6 +61,7 @@ splitLets' = \sub -> \case EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) @@ -61,9 +69,11 @@ splitLets' = \sub -> \case EShape x e -> EShape x (splitLets' sub e) EOp x op e -> EOp x op (splitLets' sub e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s @@ -87,15 +97,42 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] @@ -123,11 +160,11 @@ split typ = case typ of STPair{} -> splitRec (EVar ext typ IZ) typ STNil -> other STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other STAccum{} -> other - STLEither{} -> other where other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) other = (Point typ IZ, BTop) @@ -142,11 +179,11 @@ splitRec rhs typ = case typ of (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other STAccum{} -> other - STLEither{} -> other where other :: (Pointers (t : env) t, Bindings Ex env '[t]) other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index efb1e04..4ddcb50 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -5,9 +5,9 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-} module AST.Types where import Data.Int (Int32, Int64) @@ -23,12 +23,11 @@ type data Ty = TNil | TPair Ty Ty | TEither Ty Ty + | TLEither Ty Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy | TAccum Ty -- ^ contained type must be a monoid type - -- sparse monoid types - | TLEither Ty Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -37,12 +36,11 @@ data STy t where STNil :: STy TNil STPair :: STy a -> STy b -> STy (TPair a b) STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) STMaybe :: STy a -> STy (TMaybe a) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) - -- sparse monoid types - STLEither :: STy a -> STy b -> STy (TLEither a b) deriving instance Show (STy t) instance GCompare STy where @@ -53,6 +51,8 @@ instance GCompare STy where STPair{} _ -> GLT ; _ STPair{} -> GGT (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') @@ -60,9 +60,7 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - STAccum{} _ -> GLT ; _ STAccum{} -> GGT - (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq @@ -173,15 +171,25 @@ type family ScalIsIntegral t where ScalIsIntegral TBool = False -- | Returns true for arrays /and/ accumulators. -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True -hasArrays (STLEither a b) = hasArrays a || hasArrays b +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 3d5f544..a22b73f 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,18 +1,22 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import AST +import AST.Sparse.Types import Data --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -40,6 +44,9 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) @@ -47,12 +54,17 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -zero SMTNil _ = ENil ext +-- don't destroy the effects! +zero SMTNil e = ELet ext e $ ENil ext zero (SMTPair t1 t2) e = ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) @@ -65,8 +77,30 @@ zero (SMTScal t) _ = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -plus SMTNil _ _ = ENil ext +-- don't destroy the effects! +plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext plus (SMTPair t1 t2) a b = let t = STPair (fromSMTy t1) (fromSMTy t2) in ELet ext a $ @@ -104,7 +138,7 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> ELet ext arg $ @@ -142,3 +176,78 @@ onehot typ topprj idx arg = case (typ, topprj) of (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index d882e28..3a97fd1 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const +import Data.GADT.Compare import Data.Kind (Type) import Data @@ -31,6 +32,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 6752c24..c6efe37 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 20575b3..6301dc1 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -28,9 +28,9 @@ data ValId t where VIPair :: ValId a -> ValId b -> ValId (TPair a b) VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value - VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) @@ -244,6 +244,35 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds t1 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + EConst _ t val -> do res <- VIScal <$> genId pure (res, EConst res t val) @@ -294,6 +323,10 @@ idana env expr = case expr of res <- genIds t4 pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') + ERecompute _ e -> do + (v, e') <- idana env e + pure (v, ERecompute v e') + EWith _ t e1 e2 -> do let t1 = typeOf e1 (_, e1') <- idana env e1 @@ -303,11 +336,11 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') EZero _ t e1 -> do -- Approximate the result of EZero to be independent from the zero info @@ -316,6 +349,13 @@ idana env expr = case expr of res <- genIds (fromSMTy t) pure (res, EZero res t e1') + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 @@ -367,11 +407,11 @@ genIds :: STy t -> IdGen (ValId t) genIds STNil = pure VINil genIds (STPair a b) = VIPair <$> genIds a <*> genIds b genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId -genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) shidsToVec SZ _ = pure VNil diff --git a/src/Array.hs b/src/Array.hs index 707dce2..6ceb9fe 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -91,6 +91,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) arrayToList :: Array n t -> [t] arrayToList (Array _ v) = V.toList v +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/CHAD.hs b/src/CHAD.hs index ac308ac..7594a0f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -11,6 +12,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -33,15 +35,14 @@ module CHAD ( import Data.Functor.Const import Data.Some -import Data.Type.Bool (If) import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack) import Analysis.Identity (ValId(..), validSplitEither) import AST import AST.Bindings import AST.Count import AST.Env +import AST.Sparse import AST.Weaken.Auto import CHAD.Accum import CHAD.EnvDescr @@ -62,14 +63,20 @@ tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = EPair ext (EVar ext t (w @> IZ)) - (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = - bindingsCollect binds sub (w .> WSink) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) -- In order from large to small: i.e. in reverse order from what we want, -- because in a Bindings, the head of the list is the bottom-most entry. @@ -140,7 +147,7 @@ growRecon t ts (Reconstructor unfbs bs) -- Add a 'fst' at the bottom of the builder stack. -- First we have to weaken most of 'bs' to skip one more binding in the -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr + (BPush (fst (weakenBindingsE (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) (WSink :: env :> (Tape (t : ts) : env))) bs)) (t @@ -190,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) -- incidentally also add a bunch of additional bindings, namely 'Reverse -- (TapeUnfoldings binds)', so the calling code just has to skip those in -- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) ---------------------------------- DERIVATIVES --------------------------------- @@ -227,26 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) -> D2Op (TScal a) t @@ -261,11 +279,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +311,7 @@ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -314,67 +332,158 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) = Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) +----------------------------------- SPARSITY ----------------------------------- ------------------------------------- SUBENVS ----------------------------------- +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext (d2M t) - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) --------------------------------- ACCUMULATORS --------------------------------- @@ -407,8 +516,8 @@ accumPromote :: forall dt env sto proxy r. -- accumulators. -> (forall shbinds. SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. @@ -422,14 +531,14 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of k (storepl `DPush` (t, vid, SAccum)) envpro prosub - (SEYes accrevsub) + (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -449,7 +558,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) (t `SCons` envpro) - (SEYes prosub) + (SEYesR prosub) (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of @@ -466,7 +575,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -484,6 +593,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STNil -> True STPair a b -> isDiscrete a && isDiscrete b STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b STMaybe a -> isDiscrete a STArr _ a -> isDiscrete a STScal st -> case st of @@ -493,26 +603,45 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STF64 -> False STBool -> True STAccum{} -> False - STLEither a b -> isDiscrete a && isDiscrete b ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- -data Ret env0 sto t = - forall shbinds tapebinds env0Merge. +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) -data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. - RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) data Rets env0 sto env list = forall shbinds tapebinds. @@ -521,113 +650,149 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env shbinds tapebinds) list) deriving instance Show (Rets env0 sto env list) +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds + let (binds', _) = weakenBindingsE w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. Descr env0 sto -> SList f b1 -> SList f b2 -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) + <- weakenRets (sinkWithBindings e0) (retConcat descr list) , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) + = Rets (bconcat e0 binds) (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) subtape subtape2) pairs)) freezeRet :: Descr env sto - -> Ret env sto t + -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) + (ELet ext (weakenExpr (autoWeak library (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) ---------------------------- THE CHAD TRANSFORMATION --------------------------- -drev :: forall env sto t. +drev :: forall env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) Idx2Di _ -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in + Ret (bconcat (rhs0 `bpush` rhs1) body0') + (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) @@ -637,317 +802,461 @@ drev des accumMap = \case (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EVar ext (contribTupTy des subRHS) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> Ret binds subtape (EPair ext a1 b1) subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> Ret e0 subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> Ret e0 subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInl ext (d1 t2) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) - (zeroTup (subList (select SMerge des) sub)) - (weakenExpr (WCopy WSink) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInr ext (d1 t1) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) - (zeroTup (subList (select SMerge des) sub)) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy WSink) e2)) + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollect a0 subtapeA - , let collectB = bindingsCollect b0 subtapeB + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) + (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut - (ELet ext + (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #ta0 :++: #tl) (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (ELInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ - in letBinds rebinds $ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #tb0 :++: #tl) (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (ELInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) EConst _ t val -> Ret BTop SETop (EConst ext t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> case d2op op of Linear d2opfun -> Ret e0 subtape (d1op op e1) sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) + Ret (e0 `bpush` e1) + (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - ECustom _ _ _ storety _ pr du a b + ECustom _ _ tb _ srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @sd) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } EError _ t s -> Ret BTop SETop (EError ext (d1 t) s) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EConstArr _ n t val -> Ret BTop SETop (EConstArr ext n t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result + | SpArr @_ @sdElt sdElt <- sd , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> + let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> + case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> + case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollect e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) + let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + Ret (mergePrimalBindings + `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) + `bpush` EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #propr :++: #d1env)) + e0)) $ + let w = autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w')) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) + (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in + ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) + in letBinds (rebinds IZ) $ + weakenExpr (autoWeak (#d (auto1 @sdElt) + &. #pro (d2ace envPro) + &. #etape (subList (bindingsBinds e0) subtapeE) + &. #prerebinds prerebinds + &. #tape (auto1 @(Tape e_tape)) + &. #ix (auto1 @shty) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) + &. #sh (auto1 @shty) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) + .> wPro (subList (bindingsBinds e0) subtapeE)) + e2) }} + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in + case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #primalzip (zipPrimalTy `SCons` SNil) + &. #efPrerebinds efPrerebinds + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a mergePrimalBindings' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + letBinds (fst (weakenBindingsE (autoWeak library + (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + layout) + ef0)) $ + elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: layout)) + ef1) $ + EPair ext + (evar IZ) + (EPair ext + (evar IZ) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ + makeAccumulators (autoWeak library #propr layout1) envPro $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ + elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ + elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ + letBinds (efRebinds (IS (IS IZ))) $ + let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in + elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ + weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + ef2) $ + EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) + (ezip + (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + plus_x₀a_f + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + } + EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> Ret e0 subtape (EUnit ext e1) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (ezeroD2 eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (sparsePlus (d2M eltty) sdElt' + (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) + (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e , STArr _ t <- typeOf e -> Ret e0 subtape (EIdx0 ext e1) sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" {- @@ -956,9 +1265,9 @@ drev des accumMap = \case | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) + (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) sub @@ -969,60 +1278,73 @@ drev des accumMap = \case -} EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e , Refl <- indexTupD1Id n - , Refl <- lemZeroInfoD2 eltty - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) + (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" @@ -1033,104 +1355,130 @@ drev des accumMap = \case ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum - EAccum{} -> err_accum EZero{} -> err_monoid + EDeepZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) -data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) + (Subenv (Append shbinds '[D1 a]) tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) + (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) +deriving instance Show (RetScoped env0 sto a s sd t) -drevScoped :: forall a s env sto t. +drevScoped :: forall a s env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum - | Just (VIArr i _) <- argids + | chcSmartWith ?config + , Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) + weakenExpr (autoWeak (#d (auto1 @sd) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (ezeroD2 argty)) + (EPair ext e2 (ENil ext)) | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext (d2M argty) (ezeroD2 argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) e2 SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index d8a71b5..a7bc53f 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -1,18 +1,59 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD and CHAD.Top. module CHAD.Accum where import AST import CHAD.Types import Data +import AST.Env +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = - makeAccumulators envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) @@ -25,3 +66,7 @@ uninvertTup (t `SCons` list) tcore e = (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index 4c287d7..49ae0e6 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des select s@SAccum (DPush des (_, _, SDiscr)) = select s des select s@SMerge (DPush des (_, _, SDiscr)) = select s des select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9e7e7f5..4814bdf 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -12,6 +12,8 @@ module CHAD.Top where import Analysis.Identity import AST +import AST.Env +import AST.Sparse import AST.SplitLets import AST.Weaken.Auto import CHAD @@ -41,39 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) - -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, _, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -83,21 +71,22 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty term')) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') where term' = identityAnalysis env (splitLets term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 74e7dbd..44ac20e 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where +import AST.Accum import AST.Types import Data @@ -11,19 +13,19 @@ type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t - D1 (TLEither a b) = TLEither (D1 a) (D1 b) type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t - D2 (TLEither a b) = TLEither (D2 a) (D2 b) type family D2s t where D2s TI32 = TNil @@ -48,11 +50,11 @@ d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" -d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -60,10 +62,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env d2M :: STy t -> SMTy (D2 t) d2M STNil = SMTNil -d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STPair a b) = SMTPair (d2M a) (d2M b) d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STArr n t) = SMTArr n (d2M t) d2M (STScal t) = case t of STI32 -> SMTNil STI64 -> SMTNil @@ -71,7 +74,6 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" -d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M @@ -95,6 +97,8 @@ data CHADConfig = CHADConfig chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool } deriving (Show) @@ -103,12 +107,14 @@ defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False + , chcSmartWith = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True - , chcArgArrayAccum = True } + , chcArgArrayAccum = True + , chcSmartWith = True } ------------------------------------ LEMMAS ------------------------------------ @@ -116,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 87c01cb..888fed4 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -19,29 +19,25 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of (Left p, Left d') -> Left (toTan t1 p d') (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" - STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" - STScal sty -> case sty of - STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der - STAccum{} -> error "Accumulators not allowed in input program" STLEither t1 t2 -> case (primal, der) of (_, Nothing) -> Nothing (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) _ -> error "Primal and cotangent disagree on LEither alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Compile.hs b/src/Compile.hs index 6ba3a39..f2063ee 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -34,6 +34,7 @@ import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) @@ -45,6 +46,7 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) +import AST.Sparse.Types (isDense) import Compile.Exec import Data import Interpreter.Rep @@ -77,7 +79,7 @@ compile = \env expr -> do let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source ["kernel"] + lib <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type @@ -86,7 +88,7 @@ compile = \env expr -> do allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) serialiseArguments args ptr $ do - callKernelFun "kernel" lib ptr + callKernelFun lib ptr ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) when (ok /= 1) $ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) @@ -125,7 +127,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type)<expr> + | CECast String CExpr -- ^ (<type>)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -214,23 +216,31 @@ repSTy (STScal st) = case st of STBool -> "uint8_t" repSTy t = genStructName t -genStructName :: STy t -> String -genStructName = \t -> "ty_" ++ gen t where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) - gen (STLEither a b) = 'L' : gen a ++ gen b +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -238,60 +248,56 @@ genStructName = \t -> "ty_" ++ gen t where -- -- For accumulation it is important that for struct representations of monoid -- types, the all-zero-bytes value corresponds to the zero value of that type. -genStruct :: String -> STy t -> [StructDecl] -genStruct name topty = case topty of +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of STNil -> - [StructDecl name "" com] + TreeNode [StructDecl name "" com] [] STPair a b -> - [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] STEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - -- TODO: put shape in the main struct, not the buffer; it's constant, after all -- TODO: no buffer if n = 0 - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] STScal _ -> - [] + TreeNode [] [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] where + name = genStructName topty com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () -genStructs ty = do - let name = genStructName ty - seen <- lift $ gets (name `Set.member`) - - if seen - then pure () - else do - -- already mark this struct as generated now, so we don't generate it - -- twice (unnecessary because no recursive types, but y'know) - lift $ modify (Set.insert name) - - () <- case ty of - STNil -> pure () - STPair a b -> genStructs a >> genStructs b - STEither a b -> genStructs a >> genStructs b - STMaybe t -> genStructs t - STArr _ t -> genStructs t - STScal _ -> pure () - STAccum t -> genStructs (fromSMTy t) - STLEither a b -> genStructs a >> genStructs b - - tell (BList (genStruct name ty)) +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) data CompState = CompState { csStructs :: Set (Some STy) @@ -340,6 +346,12 @@ emitStruct ty = CompM $ do modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + emitTLD :: String -> CompM () emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } @@ -463,6 +475,15 @@ serialise topty topval ptr off k = (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -471,19 +492,18 @@ serialise topty topval ptr off k = serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t - allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do when debugRefc $ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh - pokeShape bufptr 0 n sh - pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + pokeByteOff @Word64 bufptr 0 (2 ^ 63) - let off1 = fromSNat n * 8 + 8 - loop i + let loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ loop (i+1) loop 0 (STScal sty, x) -> case sty of @@ -493,15 +513,6 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -518,6 +529,13 @@ deserialise topty ptr off = if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -525,13 +543,12 @@ deserialise topty ptr off = else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape bufptr 0 n - refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 when debugRefc $ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let off1 = 8 * fromSNat n + 8 - eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) when (refc < 2 ^ 62) $ free bufptr return arr STScal sty -> case sty of @@ -541,13 +558,6 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -569,10 +579,14 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr _ _) = (8, 8) +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) @@ -580,10 +594,6 @@ metricsSTy (STScal sty) = case sty of STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t metricsSTy (STAccum t) = metricsSTy (fromSMTy t) -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -747,15 +757,17 @@ compile' env = \case return (CELit retvar) EConstArr _ n t (Array sh vec) -> do - strname <- emitStruct (STArr n (STScal t)) + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ - "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh @@ -770,7 +782,7 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) @@ -799,7 +811,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name @@ -845,7 +857,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) let vecwid = 8 :: Int ivar <- genName' "i" @@ -909,6 +921,136 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -934,7 +1076,7 @@ compile' env = \case when emitChecks $ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ arrname ++ ".buf); return false;") @@ -977,6 +1119,8 @@ compile' env = \case maybe (return ()) ($ name2) mfun2 return (CELit name) + ERecompute _ e -> compile' env e + EWith _ t e1 e2 -> do actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 @@ -1000,95 +1144,7 @@ compile' env = \case rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a - -- full zero array with the given zero info (for the type SMTArr n t1). - initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () - initZeroArray n t1 v vzi = do - shszname <- genName' "inacshsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) - newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) - emit $ SAsg v (CELit newarrName) - forM_ (initZeroFromMemset t1) $ \f1 -> do - ivar <- genName' "i" - ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts - - -- If something needs to be done to properly initialise this type to - -- zero after memory has already been initialised to all-zero bytes, - -- returns an action that does so. - -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) - initZeroFromMemset SMTNil = Nothing - initZeroFromMemset (SMTPair t1 t2) = - case (initZeroFromMemset t1, initZeroFromMemset t2) of - (Nothing, Nothing) -> Nothing - (mf1, mf2) -> Just $ \v vzi -> do - forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") - forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") - initZeroFromMemset SMTLEither{} = Nothing - initZeroFromMemset SMTMaybe{} = Nothing - initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi - initZeroFromMemset SMTScal{} = Nothing - - let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroZI :: SMTy a -> String -> String -> CompM () - initZeroZI SMTNil _ _ = return () - initZeroZI (SMTPair t1 t2) v vzi = do - initZeroZI t1 (v++".a") (vzi++".a") - initZeroZI t2 (v++".b") (vzi++".b") - initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZeroZI (SMTScal sty) v _ = case sty of - STI32 -> emit $ SAsg v (CELit "0") - STI64 -> emit $ SAsg v (CELit "0l") - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - - let -- Initialise an uninitialised accumulation value, potentially already - -- with the addend, potentially to zero depending on the nature of the - -- projection. - -- 1. If the projection indexes only through dense monoids before - -- reaching SAPHere, the thing cannot be initialised to zero with - -- only an AcIdx; it would need to model a zero after the addend, - -- which is stupid and redundant. In this case, we return Left: - -- (accumulation value) (AcIdx value) (addend value). - -- The addend is copied, not consumed. (We can't reliably _always_ - -- consume it, so it's not worth trying to do it sometimes.) - -- 2. Otherwise, a sparse monoid is found along the way, and we can - -- initalise the dense prefix of the path to zero by setting the - -- indexed-through sparse value to a sparse zero. Afterwards, the - -- main recursion can proceed further. In this case, we return - -- Right: (accumulation value) (AcIdx value) - -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) - initZeroChunk :: SMTy a -> SAcPrj p a b - -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend - (String -> String -> CompM ()) -- zero initialisation of sparse chunk - initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of - -- reached target before the first sparse constructor - (t1 , SAPHere ) -> Left $ \v _ addend -> do - incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend - emit $ SAsg v (CELit addend) - -- sparse types - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - -- dense types - (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - f (v++".a") (i++".a") - initZeroZI t2 (v++".b") (i++".b") - (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do - initZeroZI t1 (v++".a") (i++".a") - f (v++".b") (i++".b") - (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - initZeroArray n t1 v (i++".a.b") - linidxvar <- genName' "li" - emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) - f (v++".buf->xs["++linidxvar++"]") (i++".b") - where - applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i - applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i - + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do let -- Add a value (s) into an existing accumulation value (d). If a sparse -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () @@ -1129,16 +1185,16 @@ compile' env = \case when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) "!=" - (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (CELit (d ++ ".sh[" ++ show j ++ "]"))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ d ++ ".buf" ++ - concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ "); " ++ "return false;") mempty @@ -1158,67 +1214,55 @@ compile' env = \case accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - accumRef (SMTLEither ta tb) prj0 v i addend = do - let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' - SAPRight prj' -> initZeroChunk tb prj' - subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") - tagval = case prj0 of SAPLeft{} -> "1" - SAPRight{} -> "2" - ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend - SAPRight prj' -> accumRef tb prj' subv i addend - case chunkres of - Left densef -> do - ((), stmtsSet) <- scope $ densef subv i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) - stmtsAdd -- TODO: emit check for consistency of tags? - Right sparsef -> do - ((), stmtsInit) <- scope $ sparsef subv i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty - forM_ stmtsAdd emit + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - case initZeroChunk tj prj' of - Left densef -> do - ((), stmtsSet1) <- scope $ densef (v++".j") i addend - ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) - stmtsAdd1 - Right sparsef -> do - ((), stmtsInit1) <- scope $ sparsef (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i addend + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval @@ -1232,6 +1276,9 @@ compile' env = \case return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1245,6 +1292,7 @@ compile' env = \case return $ CEStruct name [] EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" @@ -1303,13 +1351,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1361,21 +1409,21 @@ toLinearIdx SZ _ _ = CELit "0" toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") toLinearIdx (SS n) arrvar idxvar = CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) "+" (CELit (idxvar ++ ".b")) -- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr -- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] -- fromLinearIdx (SS n) arrvar idxvar = do -- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) -- _ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" @@ -1390,9 +1438,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr])] - forM_ (zip shape [0::Int ..]) $ \(dim, i) -> - emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" @@ -1403,16 +1450,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] compileArrShapeComponents n var = - [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1431,6 +1478,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] @@ -1503,7 +1553,7 @@ compileExtremum nameBase opName operator env e = do lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" @@ -1574,7 +1624,7 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (hasArrays (fromSMTy t)) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit toptyname name @@ -1583,7 +1633,7 @@ copyForWriting topty var = case topty of let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" emit $ SVerbatim $ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ ");" emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) @@ -1594,8 +1644,7 @@ copyForWriting topty var = case topty of in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ printCExpr 0 databytes ");"]) @@ -1612,8 +1661,7 @@ copyForWriting topty var = case topty of name <- genName emit $ SVarDecl False toptyname name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain @@ -1657,6 +1705,14 @@ zeroRefcountCheck toptyp opname topvar = go (STEither a b) path = do (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) go (STMaybe a) path = do ss <- go a (path++".j") return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty @@ -1673,14 +1729,6 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index d708fc0..cc6d5fa 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -11,8 +11,6 @@ module Compile.Exec ( import Control.Monad (when) import Data.IORef -import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) import Foreign (Ptr) import Foreign.Ptr (FunPtr) import System.Directory (removeDirectoryRecursive) @@ -30,10 +28,10 @@ debug :: Bool debug = False -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) -data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ())))) +data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> [String] -> IO KernelLib -buildKernel csource funnames = do +buildKernel :: String -> String -> IO KernelLib +buildKernel csource funname = do template <- (++ "/tmp.chad.") <$> getTempDir path <- mkdtemp template @@ -44,7 +42,8 @@ buildKernel csource funnames = do ,"-o", outso, "-" ,"-Wall", "-Wextra" ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" - ,"-Wno-unused-parameter", "-Wno-unused-function"] + ,"-Wno-unused-parameter", "-Wno-unused-function" + ,"-Wno-alloc-size-larger-than"] -- ideally we'd keep this, but gcc reports false positives (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource -- Print the source before the GCC output. @@ -69,8 +68,7 @@ buildKernel csource funnames = do removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now - ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames] - ref <- newIORef ptrs + ref <- newIORef =<< dlsym dl funname _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" dlclose dl) @@ -81,10 +79,10 @@ foreign import ccall "dynamic" -- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive {-# NOINLINE callKernelFun #-} -callKernelFun :: String -> KernelLib -> Ptr () -> IO () -callKernelFun key (KernelLib ref) arg = do - mp <- readIORef ref - wrapKernelFun (mp Map.! key) arg +callKernelFun :: KernelLib -> Ptr () -> IO () +callKernelFun (KernelLib ref) arg = do + ptr <- readIORef ref + wrapKernelFun ptr arg getTempDir :: IO FilePath getTempDir = diff --git a/src/Data.hs b/src/Data.hs index e86aaa6..e6978c8 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -8,12 +8,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl)) where +module Data (module Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare import Data.GADT.Show import Data.Some +import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -184,3 +185,8 @@ instance Applicative Bag where instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs index 9c10421..2712b08 100644 --- a/src/Data/VarMap.hs +++ b/src/Data/VarMap.hs @@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' subMap subenv = let bools = let loop :: Subenv env env' -> [Bool] loop SETop = [] - loop (SEYes sub) = True : loop sub + loop (SEYesR sub) = True : loop sub loop (SENo sub) = False : loop sub in VS.fromList $ loop subenv newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env superMap subenv = let loop :: Subenv env env' -> Int -> [Int] loop SETop _ = [] - loop (SEYes sub) i = i : loop sub (i+1) + loop (SEYesR sub) i = i : loop sub (i+1) loop (SENo sub) i = loop sub (i+1) newIndices = VS.fromList $ loop subenv 0 diff --git a/src/Example.hs b/src/Example.hs index 3623d03..2c51291 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -5,13 +5,18 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} + +{-# OPTIONS -Wno-unused-imports #-} module Example where import Array import AST +import AST.Count import AST.Pretty +import AST.UnMonoid import CHAD import CHAD.Top +import CHAD.Types import ForwardAD import Interpreter import Language @@ -24,6 +29,22 @@ import Example.Types -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + chad' config knownEnv $ + simplifyFix $ + term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) @@ -159,8 +180,18 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 5756f96..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,10 +26,10 @@ type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t - Tan (TLEither a b) = TLEither (Tan a) (Tan b) type family TanS t where TanS TI32 = TNil @@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t) tanty STNil = STNil tanty (STPair a b) = STPair (tanty a) (tanty b) tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanty (STMaybe t) = STMaybe (tanty t) tanty (STArr n t) = STArr n (tanty t) tanty (STScal t) = case t of @@ -55,7 +56,6 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" -tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -66,6 +66,9 @@ zeroTan STNil () = () zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) zeroTan (STEither a _) (Left x) = Left (zeroTan a x) zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) zeroTan (STMaybe _) Nothing = Nothing zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y tanScalars (STEither a _) (Left x) = tanScalars a x tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) = unzipDN (STEither a b) d = case d of Left d1 -> bimap Left Left (unzipDN a d1) Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) unzipDN (STMaybe t) d = case d of Nothing -> (Nothing, Nothing) Just d' -> bimap Just Just (unzipDN t d') @@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -unzipDN (STLEither a b) d = case d of - Nothing -> (Nothing, Nothing) - Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) - Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of (Left x', Left y') -> dotprodTan a x' y' (Right x', Right y') -> dotprodTan b x' y' _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" dotprodTan (STMaybe t) x y = case (x, y) of (Nothing, Nothing) -> 0.0 (Just x', Just y') -> dotprodTan t x' y' @@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -dotprodTan (STLEither a b) x y = case (x, y) of - (Nothing, _) -> 0.0 -- 0 * y = 0 - (_, Nothing) -> 0.0 -- x * 0 = 0 - (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' - (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible LEither alternatives" -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t) dnConst STNil = const () dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) dnConst (STMaybe t) = fmap (dnConst t) dnConst (STArr _ t) = arrayMap (dnConst t) dnConst (STScal t) = case t of @@ -196,7 +197,6 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e = case e of Left x -> \f -> Left (dnOnehots t1 x (f . Left)) Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnOnehots (STMaybe t) m = case m of Nothing -> \_ -> Nothing @@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -dnOnehots (STLEither t1 t2) e = - case e of - Nothing -> \_ -> Nothing - Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) - Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index ebc70d7..44bdbb2 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -168,6 +168,8 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -185,16 +187,22 @@ dfwdDN = \case ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e EError _ t s -> EError ext (dn t) s EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) @@ -218,4 +226,4 @@ dfwdDN = \case (EFst ext (EVar ext tIxN (IS IZ))))))) (ESnd ext (EVar ext t2 (IS IZ))) (zeroScalarConst t)))) - (EMaximum1Inner ext (dfwdDN e)) + (extremum (dfwdDN e)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index 3c76cbe..dcacf5f 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -12,10 +12,10 @@ type family DN t where DN TNil = TNil DN (TPair a b) = TPair (DN a) (DN b) DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t - DN (TLEither a b) = TLEither (DN a) (DN b) type family DNS t where DNS TF32 = TPair (TScal TF32) (TScal TF32) @@ -32,6 +32,7 @@ dn :: STy t -> STy (DN t) dn STNil = STNil dn (STPair a b) = STPair (dn a) (dn b) dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) dn (STMaybe t) = STMaybe (dn t) dn (STArr n t) = STArr n (dn t) dn (STScal t) = case t of @@ -41,7 +42,6 @@ dn (STScal t) = case t of STI64 -> STScal STI64 STBool -> STScal STBool dn STAccum{} = error "Accum in source program" -dn (STLEither a b) = STLEither (dn a) (dn b) dne :: SList STy env -> SList STy (DNE env) dne SNil = SNil diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d7916d8..9e3d2a6 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,12 +21,16 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) +import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity import qualified Data.Functor.Product as Product import Data.Int (Int64) import Data.IORef +import Data.Tuple (swap) import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) @@ -35,6 +39,7 @@ import Debug.Trace import Array import AST import AST.Pretty +import AST.Sparse.Types import Data import Interpreter.Rep @@ -141,6 +146,43 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -153,18 +195,22 @@ interpret'Rec env = \case e1' <- interpret' env e1 e2' <- interpret' env e2 interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e EWith _ t e1 e2 -> do initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do + EAccum _ t p e1 sp e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val + accumAddSparseD t p accum idx sp val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -215,6 +261,19 @@ zeroM typ zi = case typ of STF32 -> 0.0 STF64 -> 0.0 +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of SMTNil -> () @@ -238,7 +297,7 @@ addM typ a b = case typ of | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) @@ -255,15 +314,6 @@ withAccum t _ initval f = AcM $ do val <- readAc t accum return (out, val) -newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) -newAcZero typ zi = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi - SMTLEither{} -> newIORef Nothing - SMTMaybe _ -> newIORef Nothing - SMTArr _ t -> arrayMapM (newAcZero t) zi - SMTScal sty -> numericIsNum sty $ newIORef 0 - newAcDense :: SMTy a -> Rep a -> IO (RepAc a) newAcDense typ val = case typ of SMTNil -> return () @@ -273,26 +323,10 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (_, SAPHere) -> newAcDense typ val - - (SMTPair t1 t2, SAPFst prj') -> - (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) - (SMTPair t1 t2, SAPSnd prj') -> - (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx - onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ziarr @@ -308,54 +342,67 @@ readAc typ val = case typ of SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () -accumAddDense typ ref val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> do - accumAddDense t1 (fst ref) (fst val) - accumAddDense t2 (snd ref) (snd val) - SMTLEither{} -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - SMTMaybe{} -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - SMTArr _ t1 -> - forM_ [0 .. arraySize ref - 1] $ \i -> - accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) - SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) - -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val - (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) (SMTArr n t1, SAPArrIdx prj') -> - let ((arrindex', ziarr), idx') = idx + let (arrindex', idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ziarr + arrsh = arrayShape ref linindex = toLinearIndex arrsh arrindex - in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) +-- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd @@ -404,3 +451,11 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 070ba4c..1682303 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -22,11 +22,11 @@ type family Rep t where Rep TNil = () Rep (TPair a b) = (Rep a, Rep b) Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) Rep (TMaybe t) = Maybe (Rep t) Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAc t - Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) -- Mutable, represents monoid types t. type family RepAc t where @@ -56,6 +56,9 @@ showValue _ STNil () = showString "()" showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x showValue d (STArr _ t) arr = showParen (d > 10) $ @@ -70,9 +73,6 @@ showValue d (STScal sty) x = case sty of STI64 -> showsPrec d x STBool -> showsPrec d x showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" -showValue _ (STLEither _ _) Nothing = showString "LNil" -showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x -showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" @@ -86,6 +86,9 @@ rnfRep STNil () = () rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y rnfRep (STEither a _) (Left x) = rnfRep a x rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y rnfRep (STMaybe _) Nothing = () rnfRep (STMaybe t) (Just x) = rnfRep t x rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = @@ -97,9 +100,6 @@ rnfRep (STScal t) x = case t of STF64 -> rnf x STBool -> rnf x rnfRep STAccum{} _ = error "Cannot rnf accumulators" -rnfRep (STLEither _ _) Nothing = () -rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x -rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y instance KnownTy t => NFData (Value t) where rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Language.hs b/src/Language.hs index 9fd5dd3..31b4b87 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -17,6 +17,7 @@ module Language ( import Array import AST +import AST.Sparse.Types import AST.Types import CHAD.Types import Data @@ -129,6 +130,17 @@ maximum1i e = NEMaximum1Inner e minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) minimum1i e = NEMinimum1Inner e +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 + +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy @@ -169,11 +181,17 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) with a (n :-> b) = NEWith (knownMTy @t) a n b -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a b c +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 8bcb5e5..c9d05c9 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import AST.Sparse.Types import CHAD.Types import Data @@ -55,6 +56,16 @@ data NExpr env t where NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -71,9 +82,12 @@ data NExpr env t where -> NExpr env a -> NExpr env b -> NExpr env t + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -201,6 +215,10 @@ fromNamedExpr val = \case NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEMaximum1Inner e -> EMaximum1Inner ext (go e) NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) @@ -215,9 +233,10 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s diff --git a/src/Simplify.hs b/src/Simplify.hs index f5eb0a1..b89d7f6 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE QuasiQuotes #-} @@ -19,13 +21,14 @@ import Control.Monad (ap) import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) import Debug.Trace import AST import AST.Count import AST.Pretty +import AST.Sparse.Types +import AST.UnMonoid (acPrjCompose) import Data import Simplify.TH @@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id) smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) smReconstruct core = SM (\ctx -> (Any False, ctx core)) -tellActed :: SM tenv tt env t () -tellActed = SM (\_ -> (Any True, ())) +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) -- more convenient in practice -acted :: SM tenv tt env t a -> SM tenv tt env t a +acted :: ActedMonad m => m a -> m a acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) -acted' :: (Any, a) -> (Any, a) -acted' (_, x) = (Any True, x) - -liftActed :: (Any, a) -> SM tenv tt env t a -liftActed pair = SM $ \_ -> pair - simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -167,10 +176,10 @@ simplify'Rec = \case ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> + EAccum _ t p e1 sp (ELet _ rhs body) acc -> acted $ simplify' $ ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) -- let () = e in () ~> e ELet _ e1 (ENil _) | STNil <- typeOf e1 -> @@ -183,10 +192,20 @@ simplify'Rec = \case ESnd _ (ECase _ e1 e2 e3) -> acted $ simplify' $ ECase ext e1 (ESnd ext e2) (ESnd ext e3) + EFst _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (EFst ext e1) (EFst ext e2) e3 + ESnd _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 -- TODO: more array indexing - EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) - EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 + + -- TODO: more array shape + EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -216,23 +235,40 @@ simplify'Rec = \case acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') (acted $ return (ENil ext)) - (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -278,6 +314,9 @@ simplify'Rec = \case EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -291,24 +330,18 @@ simplify'Rec = \case e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) pure (ECustom ext s t p a' b' c' e1' e2') + ERecompute _ e -> [simprec| ERecompute ext *e |] EWith _ t e1 e2 -> do e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') - EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b + -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] + -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - EUnit _ e -> cheapExpr e - _ -> False - -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. hasAdds :: Expr x env t -> Bool @@ -337,6 +370,9 @@ hasAdds = \case EReplicate1Inner _ a b -> hasAdds a || hasAdds b EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e @@ -345,8 +381,10 @@ hasAdds = \case EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ _ _ -> True + ERecompute _ e -> hasAdds e + EAccum _ _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -359,57 +397,167 @@ checkAccumInScope = \case SNil -> False check STNil = False check (STPair s t) = check s || check t check (STEither s t) = check s || check t + check (STLEither s t) = check s || check t check (STMaybe t) = check t check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True - check (STLEither s t) = check s || check t -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht -simplifyOneHotTerm :: OneHotTerm env p a b - -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) - -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do - val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 - case val1' of - EZero{} -> kzero - EOneHot _ t2 prj2 idx2 val2 - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do - tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k - _ -> case prj1 of - SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' (a', b') -> return $ EPair ext a' b' recogniseMonoid typ@(SMTLEither t1 t2) expr = case expr of - ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e _ -> return expr recogniseMonoid typ@(SMTMaybe t1) expr = case expr of - ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e _ -> return expr recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted' $ do + acted $ do e' <- recogniseMonoid t e return $ ELet ext e' $ @@ -418,59 +566,33 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = (ENil ext)) (EVar ext (fromSMTy t) IZ) recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs index 2e0076a..03a74de 100644 --- a/src/Simplify/TH.hs +++ b/src/Simplify/TH.hs @@ -3,7 +3,7 @@ module Simplify.TH (simprec) where import Data.Bifunctor (first) import Data.Char -import Data.List (foldl1') +import Data.List (foldl', foldl1') import Language.Haskell.TH import Language.Haskell.TH.Quote import Text.ParserCombinators.ReadP diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index e0dc4b3..5ceb866 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -2,11 +2,13 @@ {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TupleSections #-} module Test.Framework ( TestTree, testGroup, - testGroupCollapse, + groupSetCollapse, testProperty, withResource, withResource', @@ -18,43 +20,68 @@ module Test.Framework ( TestName, ) where -import Control.Monad (forM, when) -import Control.Monad.Trans.State.Strict +import Control.Concurrent (setNumCapabilities, forkIO, killThread, forkOn) +import Control.Concurrent.MVar +import Control.Concurrent.STM +import Control.Exception (finally) +import Control.Monad (forM, when, forM_, replicateM_) import Control.Monad.IO.Class +import Data.IORef import Data.List (isInfixOf, intercalate) import Data.Maybe (isJust, mapMaybe, fromJust) +import Data.PQueue.Prio.Min qualified as PQ import Data.String (fromString) import Data.Time.Clock -import System.Environment +import GHC.Conc (getNumProcessors) +import System.Console.ANSI qualified as ANSI +import System.Console.Concurrent (outputConcurrent) +import System.Console.Regions +import System.Environment (getArgs, getProgName) import System.Exit import System.IO (hFlush, hPutStrLn, stdout, stderr, hIsTerminalDevice) import Text.Read (readMaybe) -import qualified Hedgehog as H -import qualified Hedgehog.Internal.Config as H -import qualified Hedgehog.Internal.Property as H -import qualified Hedgehog.Internal.Report as H -import qualified Hedgehog.Internal.Runner as H -import qualified Hedgehog.Internal.Seed as H.Seed +import Hedgehog qualified as H +import Hedgehog.Internal.Config qualified as H +import Hedgehog.Internal.Property qualified as H +import Hedgehog.Internal.Report qualified as H +import Hedgehog.Internal.Runner qualified as H +import Hedgehog.Internal.Seed qualified as H.Seed +type TestName = String + data TestTree - = Group Bool String [TestTree] - | forall a. Resource (IO a) (a -> IO ()) (a -> TestTree) + = Group GroupOpts String [TestTree] + | forall a. Resource String (IO a) (a -> IO ()) (a -> TestTree) + -- ^ Name is not specified by user, but inherited from the tree below | HP String H.Property -type TestName = String +-- Not exported because a Resource is not supposed to have a name in the first place +treeName :: TestTree -> String +treeName (Group _ name _) = name +treeName (Resource name _ _ _) = name +treeName (HP name _) = name + +data GroupOpts = GroupOpts + { goCollapse :: Bool } + deriving (Show) + +defaultGroupOpts :: GroupOpts +defaultGroupOpts = GroupOpts False testGroup :: String -> [TestTree] -> TestTree -testGroup = Group False +testGroup = Group defaultGroupOpts -testGroupCollapse :: String -> [TestTree] -> TestTree -testGroupCollapse = Group True +groupSetCollapse :: TestTree -> TestTree +groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name trees +groupSetCollapse _ = error "groupSetCollapse: not called on a Group" --- | The @a -> TestTree@ function must use the @a@ only inside properties: when --- not actually running properties, it will be passed 'undefined'. +-- | The @a -> TestTree@ function must use the @a@ only inside properties: the +-- functoin will be passed 'undefined' when exploring the test tree (without +-- running properties). withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree -withResource = Resource +withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun -- | Same caveats as 'withResource'. withResource' :: IO a -> (a -> TestTree) -> TestTree @@ -66,14 +93,14 @@ testProperty = HP filterTree :: Options -> TestTree -> Maybe TestTree filterTree (Options { optsPattern = pat }) = go [] where - go path (Group collapse name trees) = + go path (Group opts name trees) = case mapMaybe (go (name:path)) trees of [] -> Nothing - trees' -> Just (Group collapse name trees') - go path (Resource make free fun) = + trees' -> Just (Group opts name trees') + go path (Resource inhname make free fun) = case go path (fun undefined) of Nothing -> Nothing - Just _ -> Just $ Resource make free (fromJust . go path . fun) + Just _ -> Just $ Resource inhname make free (fromJust . go path . fun) go path hp@(HP name _) | pat `isInfixOf` renderPath (name:path) = Just hp | otherwise = Nothing @@ -84,9 +111,12 @@ computeMaxLen :: TestTree -> Int computeMaxLen = go 0 where go :: Int -> TestTree -> Int - go indent (Group True name trees) = maximum (2*indent + length name : map (go (indent+1)) trees) - go indent (Group False _ trees) = maximum (0 : map (go (indent+1)) trees) - go indent (Resource _ _ fun) = go indent (fun undefined) + go indent (Group opts name trees) + -- If we collapse, the name of the group gets prefixed before the final status message after collapsing. + | goCollapse opts = maximum (2*indent + length name : map (go (indent+1)) trees) + -- If we don't collapse, the group name does get printed but without any status message, so it doesn't need to get accounted for in maxlen. + | otherwise = maximum (0 : map (go (indent+1)) trees) + go indent (Resource _ _ _ fun) = go indent (fun undefined) go indent (HP name _) = 2 * indent + length name data Stats = Stats @@ -97,22 +127,21 @@ data Stats = Stats initStats :: Stats initStats = Stats 0 0 -newtype M a = M (StateT Stats IO a) - deriving newtype (Functor, Applicative, Monad, MonadIO) - -modifyStats :: (Stats -> Stats) -> M () -modifyStats f = M (modify f) +modifyStats :: (?stats :: IORef Stats) => (Stats -> Stats) -> IO () +modifyStats f = atomicModifyIORef' ?stats (\s -> (f s, ())) data Options = Options { optsPattern :: String , optsHelp :: Bool , optsHedgehogReplay :: Maybe (H.Skip, H.Seed) , optsHedgehogShrinks :: Maybe Int + , optsParallel :: Bool + , optsVerbose :: Bool } deriving (Show) defaultOptions :: Options -defaultOptions = Options "" False Nothing Nothing +defaultOptions = Options "" False Nothing Nothing False False parseOptions :: [String] -> Options -> Either String Options parseOptions [] opts = pure opts @@ -134,6 +163,8 @@ parseOptions ("--hedgehog-shrinks":arg:args) opts = case readMaybe arg of Just n -> parseOptions args opts { optsHedgehogShrinks = Just n } Nothing -> Left "Invalid argument to '--hedgehog-shrinks'" +parseOptions ("--parallel":args) opts = parseOptions args opts { optsParallel = True } +parseOptions ("--verbose":args) opts = parseOptions args opts { optsVerbose = True } parseOptions (arg:_) _ = Left $ "Unrecognised argument: '" ++ arg ++ "'" printUsage :: IO () @@ -147,7 +178,12 @@ printUsage = do ," test looks like: '^group1/group2/testname$'." ," --hedgehog-replay '{skip} {seed}'" ," Skip to a particular generated Hedgehog test. Should be used" - ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."] + ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set." + ," --hedgehog-shrinks NUM" + ," Limit the number of shrinking steps." + ," --parallel Run tests in parallel." + ," --verbose Currently only has an effect with --parallel. Also shows OK" + ," and timing for test groups, not only individual tests."] defaultMain :: TestTree -> IO () defaultMain tree = do @@ -165,58 +201,180 @@ runTests options = \tree' -> return (ExitFailure 1) Just tree -> do isterm <- hIsTerminalDevice stdout - let M m = let ?maxlen = computeMaxLen tree - ?istty = isterm - in go 0 id tree starttm <- getCurrentTime - (success, stats) <- runStateT m initStats + statsRef <- newIORef initStats + success <- + let ?stats = statsRef + ?options = options + ?maxlen = computeMaxLen tree + ?istty = isterm + in if optsParallel options + then do nproc <- getNumProcessors + setNumCapabilities nproc + displayConsoleRegions $ + withWorkerPool nproc $ \pool -> do + let ?pool = pool + successVar <- newEmptyMVar + runTreePar Nothing [] [] tree successVar + readMVar successVar + else isJust <$> runTreeSeq 0 [] tree + stats <- readIORef statsRef endtm <- getCurrentTime let ?istty = isterm in printStats stats (diffUTCTime endtm starttm) - return (if isJust success then ExitSuccess else ExitFailure 1) + return (if success then ExitSuccess else ExitFailure 1) + +-- | Returns when all jobs in this tree have been scheduled. When all jobs are +-- done, the outvar is filled with whether all tests in this tree were +-- successful. +-- The mparregion is the parent region to take over, if any. Having a parent +-- region to take over implies that we are currently executing in a worker and +-- can hence run blocking user code directly in the current thread. +runTreePar :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> [String] -> TestTree -> MVar Bool -> IO () +-- TODO: handle collapse somehow? +runTreePar mparregion revidxlist revpath (Group _ name trees) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + -- outputConcurrent $ "! " ++ path ++ ": Started\n" + + -- If we have exactly one child and we're currently running in a worker, we can continue doing so + mparregion2 <- case trees of + [_] -> return mparregion + _ -> do -- If not, we have to close the region (and implicitly relinquish the worker job) + forM_ mparregion closeConsoleRegion + return Nothing + + starttm <- getCurrentTime + suboutvars <- forM (zip trees [0..]) $ \(tree, idx) -> do + suboutvar <- newEmptyMVar + runTreePar mparregion2 (idx : revidxlist) (name : revpath) tree suboutvar + return suboutvar + + -- outputConcurrent $ "! " ++ path ++ ": Scheduled all\n" + + -- If we took over the parent region then this readMVar will resolve + -- immediately and the forkIO would be unnecessary. Meh. + _ <- forkIO $ do + success <- and <$> traverse readMVar suboutvars + endtm <- getCurrentTime + -- outputConcurrent $ "! " ++ path ++ ": Done\n" + if success && optsVerbose ?options + then let text = path ++ ": " ++ ansiGreen ++ "OK" ++ ansiReset ++ " " ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + in outputConcurrent (text ++ "\n") + else return () + putMVar outvar success + + return () + +runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runResource topmparregion 1 toptree topoutvar where - -- If all tests are successful, returns the number of output lines produced - go :: (?maxlen :: Int, ?istty :: Bool) => Int -> (String -> String) -> TestTree -> M (Maybe Int) - go indent path (Group collapse name trees) = do - liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout - starttm <- liftIO getCurrentTime - mlns <- fmap (fmap sum . sequence) . forM trees $ - go (indent + 1) (path . (name++) . ('/':)) - endtm <- liftIO getCurrentTime - case mlns of - Just lns | collapse, ?istty -> do - let thislen = 2*indent + length name - liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++ - replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ - "\x1B[32mOK\x1B[0m" ++ - prettyDuration False (realToFrac (diffUTCTime endtm starttm)) - return (Just 1) - _ -> return mlns - go indent path (Resource make cleanup fun) = do - value <- liftIO make - success <- go indent path (fun value) - liftIO $ cleanup value - return success - go indent path (HP name (H.Property config test)) = do + runResource + :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> Int -> TestTree -> MVar Bool -> IO () + runResource mparregion depth (Resource inhname make cleanup fun) outvar = do + let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")" + path = intercalate "/" (reverse (pathitem : revpath)) + idxlist = reverse revidxlist + -- outputConcurrent $ "! " ++ path ++ ": R Submitting\n" + submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do + setConsoleRegion makeRegion ('|' : path ++ " [R] making...") + + -- outputConcurrent $ "! " ++ path ++ ": R Making\n" + value <- make -- TODO: catch exceptions + -- outputConcurrent $ "! " ++ path ++ ": R Made\n" + + -- outputConcurrent $ "! " ++ path ++ ": R Running subtree\n" + suboutvar <- newEmptyMVar + runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion + -- outputConcurrent $ "! " ++ path ++ ": R Scheduled subtree\n" + + _ <- forkIO $ do + success <- readMVar suboutvar + -- outputConcurrent $ "! " ++ path ++ ": R Subtree done, scheduling cleanup\n" + poolSubmit ?pool idxlist (Just outvar) $ do + cleanupRegion <- openConsoleRegion Linear + setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...") + -- outputConcurrent $ "! " ++ path ++ ": R Cleaning up\n" + cleanup value -- TODO: catch exceptions + -- outputConcurrent $ "! " ++ path ++ ": R Cleanup done\n" + closeConsoleRegion cleanupRegion + return success + return () + runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar + +runTreePar mparregion revidxlist revpath (HP name prop) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + idxlist = reverse revidxlist + + submitOrRunIn mparregion idxlist (Just outvar) $ \region -> do + -- outputConcurrent $ "! " ++ path ++ ": Running" + let prefix = path ++ " [T]" + setConsoleRegion region ('|' : prefix) + let progressHandler report = do + str <- H.renderProgress H.EnableColor (Just (fromString "")) report + setConsoleRegion region ('|' : prefix ++ ": " ++ replace '\n' " " str) + (ok, rendered) <- runHP progressHandler revpath name prop + finishConsoleRegion region (path ++ ": " ++ rendered) + return ok + +submitOrRunIn :: (?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> Maybe (MVar a) -> (ConsoleRegion -> IO a) -> IO () +submitOrRunIn Nothing idxlist outvar fun = + poolSubmit ?pool idxlist outvar (openConsoleRegion Linear >>= fun) +submitOrRunIn (Just reg) _idxlist outvar fun = do + result <- fun reg + forM_ outvar $ \mvar -> putMVar mvar result + +-- | If all tests are successful, returns the number of output lines produced +runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) + => Int -> [String] -> TestTree -> IO (Maybe Int) +runTreeSeq indent revpath (Group opts name trees) = do + putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout + starttm <- getCurrentTime + mlns <- fmap (fmap sum . sequence) . forM trees $ + runTreeSeq (indent + 1) (name : revpath) + endtm <- getCurrentTime + case mlns of + Just lns | goCollapse opts, ?istty -> do let thislen = 2*indent + length name - let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' - when ?istty $ liftIO $ putStr outputPrefix >> hFlush stdout + putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++ + ANSI.setCursorColumnCode 0 ++ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + ansiGreen ++ "OK" ++ ansiReset ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return (Just 1) + _ -> return ((+1) <$> mlns) +runTreeSeq indent path (Resource _ make cleanup fun) = do + value <- make + success <- runTreeSeq indent path (fun value) + cleanup value + return success +runTreeSeq indent path (HP name prop) = do + let thislen = 2*indent + length name + let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' + when ?istty $ putStr prefix >> hFlush stdout + (ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop + putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout + return (if ok then Just 1 else Nothing) - let (config', seedfun) = applyHedgehogOptions options config - seed <- seedfun +runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int) + => (H.Report H.Progress -> IO ()) + -> [String] + -> String -> H.Property -> IO (Bool, String) +runHP progressPrinter revpath name (H.Property config test) = do + let (config', seedfun) = applyHedgehogOptions ?options config + seed <- seedfun - starttm <- liftIO getCurrentTime - report <- liftIO $ H.checkReport config' 0 seed test (outputProgress (?maxlen + 2)) - endtm <- liftIO getCurrentTime + starttm <- getCurrentTime + report <- H.checkReport config' 0 seed test progressPrinter + endtm <- getCurrentTime - liftIO $ do - when (not ?istty) $ putStr outputPrefix - printResult report (path name) (diffUTCTime endtm starttm) - hFlush stdout + rendered <- renderResult report (intercalate "/" (reverse (name : revpath))) (diffUTCTime endtm starttm) - let ok = H.reportStatus report == H.OK - modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats - , statsTotal = 1 + statsTotal stats } - return (if ok then Just 1 else Nothing) + let ok = H.reportStatus report == H.OK + modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats + , statsTotal = 1 + statsTotal stats } + return (ok, rendered) applyHedgehogOptions :: MonadIO m => Options -> H.PropertyConfig -> (H.PropertyConfig, m H.Seed) applyHedgehogOptions opts config0 = @@ -232,32 +390,65 @@ outputProgress :: (?istty :: Bool) => Int -> H.Report H.Progress -> IO () outputProgress indent report | ?istty = do str <- H.renderProgress H.EnableColor (Just (fromString "")) report - putStr (replace '\n' " " str ++ "\x1B[" ++ show (indent+1) ++ "G") + putStr (replace '\n' " " str ++ ANSI.setCursorColumnCode indent) hFlush stdout | otherwise = return () -printResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO () -printResult report path timeTaken = do +renderResult :: H.Report H.Result -> String -> NominalDiffTime -> IO String +renderResult report path timeTaken = do str <- H.renderResult H.EnableColor (Just (fromString "")) report case H.reportStatus report of - H.OK -> putStrLn (ansi "\x1B[K" ++ str ++ prettyDuration False (realToFrac timeTaken)) + H.OK -> return (str ++ prettyDuration False (realToFrac timeTaken)) H.Failed failure -> do let H.Report { H.reportTests = count, H.reportDiscards = discards } = report replayInfo = H.skipCompress (H.SkipToShrink count discards (H.failureShrinkPath failure)) ++ " " ++ show (H.reportSeed report) suffix = "\n Flags to reproduce: `-p '" ++ path ++ "' --hedgehog-replay '" ++ replayInfo ++ "'`" - putStrLn (ansi "\x1B[K" ++ str ++ suffix) - _ -> putStrLn (ansi "\x1B[K" ++ str) + return (str ++ suffix) + _ -> return str printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO () printStats stats timeTaken | statsOK stats == statsTotal stats = do - putStrLn $ ansi "\x1B[32m" ++ "All " ++ show (statsTotal stats) ++ - " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" + putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++ + " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset | otherwise = let nfailed = statsTotal stats - statsOK stats - in putStrLn $ ansi "\x1B[31m" ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ - " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" + in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ + " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + + +newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob))) +data PoolJob = forall a. PoolJob (IO a) (Maybe (MVar a)) +data Terminate a = Terminate | Value a + deriving (Eq, Ord) -- Terminate sorts before Value + +withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a +withWorkerPool numWorkers k = do + chan <- newTVarIO PQ.empty + threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan)) + k (WorkerPool chan) `finally` do + replicateM_ numWorkers (atomically $ writeTVar chan PQ.empty) + forM_ threads killThread + where + worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO () + worker idx chan = do + job <- atomically $ + readTVar chan >>= \case + PQ.Empty -> retry + (_, j) PQ.:< q -> writeTVar chan q >> return j + case job of + Value (PoolJob action mmvar) -> do + -- outputConcurrent $ "[" ++ show idx ++ "] got job\n" + result <- action + forM_ mmvar $ \mvar -> putMVar mvar result + worker idx chan + Terminate -> return () + +poolSubmit :: Ord k => WorkerPool k -> k -> Maybe (MVar a) -> IO a -> IO () +poolSubmit (WorkerPool chan) key mmvar action = + atomically $ modifyTVar chan $ PQ.insert key (Value (PoolJob action mmvar)) + prettyDuration :: Bool -> Double -> String prettyDuration False x | x < 0.5 = "" @@ -273,3 +464,23 @@ replace x ys = concatMap (\y -> if y == x then ys else [y]) ansi :: (?istty :: Bool) => String -> String ansi | ?istty = id | otherwise = const "" + +ansiRed, ansiGreen, ansiReset :: (?istty :: Bool) => String +ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red]) +ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green]) +ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset]) + +-- getTermIsDark :: IO Bool +-- getTermIsDark = do +-- mclr <- ANSI.getLayerColor ANSI.Background +-- case mclr of +-- Nothing -> return True +-- Just (RGB r g b) -> +-- let cvt n = fromIntegral n / fromIntegral (maxBound `asTypeOf` n) +-- in return $ (cvt r + cvt g + cvt b) / 3 < (0.5 :: Double) + +-- ansiRegionBg :: (?istty :: Bool, ?termisdark :: Bool) => String +-- ansiRegionBg +-- | not ?istty = "" +-- | ?termisdark = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.0 0.05 0.1)] +-- | otherwise = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.95 0.95 1.0)] diff --git a/test/Main.hs b/test/Main.hs index f5e4a3c..4bc9082 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -25,6 +25,7 @@ import Test.Framework import Array import AST hiding ((.>)) +import AST.Count (pruneExpr) import AST.Pretty import AST.UnMonoid import CHAD.Top @@ -63,18 +64,18 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False env input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 @@ -85,19 +86,19 @@ extendDN STNil () = pure () extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y extendDN (STEither a _) (Left x) = Left <$> extendDN a x extendDN (STEither _ b) (Right y) = Right <$> extendDN b y +extendDN (STLEither _ _) Nothing = pure Nothing +extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x +extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y extendDN (STMaybe _) Nothing = pure Nothing extendDN (STMaybe t) (Just x) = Just <$> extendDN t x extendDN (STArr _ t) arr = traverse (extendDN t) arr extendDN (STScal sty) x = case sty of - STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) - STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) STI32 -> pure x STI64 -> pure x STBool -> pure x extendDN (STAccum _) _ = error "Accumulators not supported in input program" -extendDN (STLEither _ _) Nothing = pure Nothing -extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x -extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env)) extendDNE SNil SNil = pure SNil @@ -116,6 +117,10 @@ closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x' closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x' closeIshT' _ STEither{} _ _ = False +closeIshT' _ (STLEither _ _) Nothing Nothing = True +closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x' +closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y' +closeIshT' _ STLEither{} _ _ = False closeIshT' _ (STMaybe _) Nothing Nothing = True closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x' closeIshT' _ STMaybe{} _ _ = False @@ -128,10 +133,6 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) closeIshT' h (STScal STF64) x y = closeIsh' h x y closeIshT' _ (STScal STBool) x y = x == y closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" -closeIshT' _ (STLEither _ _) Nothing Nothing = True -closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x' -closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y' -closeIshT' _ STLEither{} _ _ = False closeIshT :: STy t -> Rep t -> Rep t -> Bool closeIshT = closeIshT' 1e-5 @@ -233,19 +234,19 @@ genValue topty tpl = case topty of STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a) ,liftV Right <$> genValue b (emptyTpl b)] + STLEither a b -> Gen.frequency [(1, pure (Value Nothing)) + ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a)) + ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))] STMaybe t -> Gen.choice [return (Value Nothing) ,liftV Just <$> genValue t (emptyTpl t)] STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of - STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10) STI32 -> genInt STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" - STLEither a b -> Gen.frequency [(1, pure (Value Nothing)) - ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a)) - ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))] where genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t) genInt = do @@ -302,10 +303,12 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS - ,adTestGenChad env envGenerator expr exprS primalSfun] + ,testGroup "chad" + [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun + ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]] adTestGenPrimal :: SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R @@ -336,24 +339,31 @@ adTestGenFwd env envGenerator exprS = diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 -adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) +adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree -adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 - dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermChadSUS = simplifyFix $ unMonoid dtermChadS + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 + dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = pruneExpr env dtermSChadSUS in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> - testProperty "chad" $ property $ do + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> + testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -363,35 +373,47 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outCompSChadS closeIsh outPrimal + annotate (ppExpr env dtermSChadSUSP) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansCompSChadS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) @@ -409,7 +431,7 @@ gen_gmm = do vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vgamma <- Gen.realFloat (Range.constant (-10) 10) vm <- Gen.integral (Range.linear 0 5) let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) k2 = 0.5 * vgamma * vgamma @@ -435,11 +457,30 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -502,25 +543,32 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree @@ -556,9 +604,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -566,6 +612,27 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ term_idx_coprod + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x @@ -607,6 +674,35 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm + + ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree + + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a + + ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a + + ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + let_ #v 2 $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a + + ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a ] main :: IO () |
