aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs267
-rw-r--r--src/AST/Accum.hs127
-rw-r--r--src/AST/Bindings.hs13
-rw-r--r--src/AST/Count.hs16
-rw-r--r--src/AST/Env.hs74
-rw-r--r--src/AST/Pretty.hs186
-rw-r--r--src/AST/Sparse.hs290
-rw-r--r--src/AST/Sparse/Types.hs107
-rw-r--r--src/AST/SplitLets.hs154
-rw-r--r--src/AST/Types.hs162
-rw-r--r--src/AST/UnMonoid.hs255
-rw-r--r--src/AST/Weaken.hs13
-rw-r--r--src/AST/Weaken/Auto.hs2
-rw-r--r--src/Analysis/Identity.hs93
-rw-r--r--src/CHAD.hs1256
-rw-r--r--src/CHAD/Accum.hs52
-rw-r--r--src/CHAD/EnvDescr.hs53
-rw-r--r--src/CHAD/Top.hs63
-rw-r--r--src/CHAD/Types.hs85
-rw-r--r--src/CHAD/Types/ToTan.hs23
-rw-r--r--src/Compile.hs598
-rw-r--r--src/Compile/Exec.hs25
-rw-r--r--src/Data.hs39
-rw-r--r--src/Data/VarMap.hs119
-rw-r--r--src/Example.hs32
-rw-r--r--src/Example/GMM.hs4
-rw-r--r--src/ForwardAD.hs30
-rw-r--r--src/ForwardAD/DualNumbers.hs9
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs2
-rw-r--r--src/Interpreter.hs398
-rw-r--r--src/Interpreter/Rep.hs69
-rw-r--r--src/Language.hs22
-rw-r--r--src/Language/AST.hs11
-rw-r--r--src/Simplify.hs543
-rw-r--r--src/Simplify/TH.hs80
35 files changed, 3733 insertions, 1539 deletions
diff --git a/src/AST.hs b/src/AST.hs
index c8377de..5aab4fc 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -16,13 +16,16 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE FlexibleInstances #-}
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
+import Data.Functor.Identity
import Data.Kind (Type)
import Array
import AST.Accum
+import AST.Sparse.Types
import AST.Types
import AST.Weaken
import CHAD.Types
@@ -33,11 +36,9 @@ import Data
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
--
--- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the
--- type transformation of CHAD. Indeed, these constructors are created _by_
--- CHAD, and are intended to be eliminated after simplification, so that the
--- input program as well as the output program do not contain these
--- constructors.
+-- All the monoid operations are unsupposed as the input to CHAD, and are
+-- intended to be eliminated after simplification, so that the input program as
+-- well as the output program do not contain these constructors.
-- TODO: ensure this by a "stage" type parameter.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
@@ -87,15 +88,28 @@ data Expr x env t where
-> Expr x env a -> Expr x env b
-> Expr x env t
+ -- fake halfway checkpointing
+ ERecompute :: x t -> Expr x env t -> Expr x env t
+
-- accumulation effect on monoids
- EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
- -- TODO: let this contain a OneHotTerm that is shared with EOneHot for uniformity in Simplify
- EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
+ -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
+ -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
+ -- need to create any zeros.
+ EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ -- The 'Sparse' here is eliminated to dense by UnMonoid.
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
- EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t)
+ EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
+ EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+
+ -- interface of abstract monoidal types
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
@@ -127,6 +141,7 @@ data SOp a t where
OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)
opt1 :: SOp a t -> STy a
@@ -147,6 +162,7 @@ opt1 = \case
OExp t -> STScal t
OLog t -> STScal t
OIDiv t -> STPair (STScal t) (STScal t)
+ OMod t -> STPair (STScal t) (STScal t)
opt2 :: SOp a t -> STy t
opt2 = \case
@@ -166,6 +182,7 @@ opt2 = \case
OExp t -> STScal t
OLog t -> STScal t
OIDiv t -> STScal t
+ OMod t -> STScal t
typeOf :: Expr x env t -> STy t
typeOf = \case
@@ -182,6 +199,10 @@ typeOf = \case
ENothing _ t -> STMaybe t
EJust _ e -> STMaybe (typeOf e)
EMaybe _ e _ _ -> typeOf e
+ ELNil _ t1 t2 -> STLEither t1 t2
+ ELInl _ t2 e -> STLEither (typeOf e) t2
+ ELInr _ t1 e -> STLEither t1 (typeOf e)
+ ELCase _ _ a _ _ -> typeOf a
EConstArr _ n t _ -> STArr n (STScal t)
EBuild _ n _ e -> STArr n (typeOf e)
@@ -200,13 +221,15 @@ typeOf = \case
EOp _ op _ -> opt2 op
ECustom _ _ _ _ e _ _ _ _ -> typeOf e
+ ERecompute _ e -> typeOf e
EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ _ -> STNil
+ EAccum _ _ _ _ _ _ _ -> STNil
- EZero _ t -> d2 t
- EPlus _ t _ _ -> d2 t
- EOneHot _ t _ _ _ -> d2 t
+ EZero _ t _ -> fromSMTy t
+ EDeepZero _ t _ -> fromSMTy t
+ EPlus _ t _ _ -> fromSMTy t
+ EOneHot _ t _ _ _ -> fromSMTy t
EError _ t _ -> t
@@ -224,6 +247,10 @@ extOf = \case
ENothing x _ -> x
EJust x _ -> x
EMaybe x _ _ _ -> x
+ ELNil x _ _ -> x
+ ELInl x _ _ -> x
+ ELInr x _ _ -> x
+ ELCase x _ _ _ _ -> x
EConstArr x _ _ _ -> x
EBuild x _ _ _ -> x
EFold1Inner x _ _ _ _ -> x
@@ -239,16 +266,70 @@ extOf = \case
EShape x _ -> x
EOp x _ _ -> x
ECustom x _ _ _ _ _ _ _ _ -> x
+ ERecompute x _ -> x
EWith x _ _ _ -> x
- EAccum x _ _ _ _ _ -> x
- EZero x _ -> x
+ EAccum x _ _ _ _ _ _ -> x
+ EZero x _ _ -> x
+ EDeepZero x _ _ -> x
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
-subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
-subst1 repl = subst $ \x t -> \case IZ -> repl
- IS i -> EVar x t i
+mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt f = runIdentity . travExt (Identity . f)
+
+{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
+travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt f = \case
+ EVar x t i -> EVar <$> f x <*> pure t <*> pure i
+ ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
+ EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
+ EFst x e -> EFst <$> f x <*> travExt f e
+ ESnd x e -> ESnd <$> f x <*> travExt f e
+ ENil x -> ENil <$> f x
+ EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
+ EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
+ ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
+ ENothing x t -> ENothing <$> f x <*> pure t
+ EJust x e -> EJust <$> f x <*> travExt f e
+ EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
+ ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
+ ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
+ ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
+ ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
+ EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
+ EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
+ EUnit x e -> EUnit <$> f x <*> travExt f e
+ EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
+ EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
+ EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EConst x t v -> EConst <$> f x <*> pure t <*> pure v
+ EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
+ EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
+ EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
+ EShape x e -> EShape <$> f x <*> travExt f e
+ EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
+ ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ ERecompute x e -> ERecompute <$> f x <*> travExt f e
+ EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
+ EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
+ EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
+ EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
+ EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
+ EError x t s -> EError <$> f x <*> pure t <*> pure s
+
+substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+substInline repl =
+ subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
+subst0 repl =
+ subst $ \_ t -> \case IZ -> repl
+ IS i -> EVar ext t (IS i)
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
-> Expr x env t -> Expr x env' t
@@ -271,6 +352,10 @@ subst' f w = \case
ENothing x t -> ENothing x t
EJust x e -> EJust x (subst' f w e)
EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (subst' f w e)
+ ELInr x t e -> ELInr x t (subst' f w e)
+ ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
@@ -286,11 +371,13 @@ subst' f w = \case
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
+ ERecompute x e -> ERecompute x (subst' f w e)
EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
- EZero x t -> EZero x t
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
+ EZero x t e -> EZero x t (subst' f w e)
+ EDeepZero x t e -> EDeepZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -302,15 +389,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-slistIdx :: SList f list -> Idx list t -> f t
-slistIdx (SCons x _) IZ = x
-slistIdx (SCons _ list) (IS i) = slistIdx list i
-slistIdx SNil i = case i of {}
-
-idx2int :: Idx env t -> Int
-idx2int IZ = 0
-idx2int (IS n) = 1 + idx2int n
-
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
@@ -322,10 +400,19 @@ class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
@@ -335,10 +422,19 @@ styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- styKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
@@ -375,27 +471,30 @@ eidxEq (SS n) a b
(eidxEq n (EFst ext (EVar ext ty (IS IZ)))
(EFst ext (EVar ext ty IZ)))
-emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr =
- let STArr n t = typeOf arr
- in ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
-
-ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
-ezipWith f arr1 arr2 =
- let STArr n t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ELet ext arr1 $
- ELet ext (weakenExpr WSink arr2) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
- weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
+emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr
+ | STArr n t <- typeOf arr
+ , Dict <- styKnown t
+ = ELet ext arr $
+ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) f
+
+ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
+ezipWith f arr1 arr2
+ | STArr n t1 <- typeOf arr1
+ , STArr _ t2 <- typeOf arr2
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELet ext arr1 $
+ ELet ext (weakenExpr WSink arr2) $
+ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
+ weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
ezip arr1 arr2 =
@@ -416,3 +515,61 @@ eshapeEmpty (SS n) e =
(EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
+-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
+-- ezeroD2 t ezi = EZero ext (d2M t) ezi
+
+-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
+-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
+
+-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
+-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
+
+eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
+eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k =
+ elet e $
+ k WSink
+ (EFst ext (evar IZ))
+ (ESnd ext (evar IZ))
+
+efst :: Ex env (TPair a b) -> Ex env a
+efst (EPair _ e1 _) = e1
+efst e = EFst ext e
+
+esnd :: Ex env (TPair a b) -> Ex env b
+esnd (EPair _ _ e2) = e2
+esnd e = ESnd ext e
+
+elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
+elet rhs body
+ | Dict <- styKnown (typeOf rhs)
+ = ELet ext rhs body
+
+emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
+emaybe e a b
+ | STMaybe t <- typeOf e
+ , Dict <- styKnown t
+ = EMaybe ext a b e
+
+elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c
+elcase e a b c
+ | STLEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELCase ext e a b c
+
+evar :: KnownTy a => Idx env a -> Ex env a
+evar = EVar ext knownTy
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 67c5de7..988a450 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,8 +1,8 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
@@ -26,35 +26,112 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPHere :: SAcPrj APHere a a
SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
- SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b
+ SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b
SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
- -- TODO: This SNat is rather useless, you always have an STy around too
- SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
+ SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b
-- TODO:
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = AcIdx p a
- AcIdx (APSnd p) (TPair a b) = AcIdx p b
- AcIdx (APLeft p) (TEither a b) = AcIdx p a
- AcIdx (APRight p) (TEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, array shape), recursive info)
- TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+type data AIDense = AID | AIS
+
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-acPrjTy :: SAcPrj p a b -> STy a -> STy b
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
+
+acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
-acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t
-acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t
-acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t
-acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t
-acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t
-acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t
+acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
+
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+
+-- -- | Additional info needed for accumulation. This is empty unless there is
+-- -- sparsity in the monoid.
+-- type family AccumInfo t where
+-- AccumInfo TNil = TNil
+-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
+-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
+-- AccumInfo (TArr n t) = TArr n (AccumInfo t)
+-- AccumInfo (TScal t) = TNil
+
+-- type family PrimalInfo t where
+-- PrimalInfo TNil = TNil
+-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
+-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
+-- PrimalInfo (TScal t) = TNil
+
+-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
+-- tPrimalInfo SMTNil = STNil
+-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
+-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
+-- tPrimalInfo (SMTScal _) = STNil
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 2e63b42..2310f4b 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -16,6 +16,7 @@
module AST.Bindings where
import AST
+import AST.Env
import Data
import Lemmas
@@ -45,7 +46,7 @@ weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
weakenOver SNil w = w
weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
-sinkWithBindings :: Bindings f env binds -> env' :> Append binds env'
+sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env'
sinkWithBindings BTop = WId
sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
@@ -62,3 +63,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYesR sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index dc8ec72..ca4d7ab 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -113,6 +113,10 @@ occCountGeneral onehot unpush alter many = go WId
ENothing _ _ -> mempty
EJust _ e -> re e
EMaybe _ a b e -> re a <> re1 b <> re e
+ ELNil _ _ _ -> mempty
+ ELInl _ _ e -> re e
+ ELInr _ _ e -> re e
+ ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
EConstArr{} -> mempty
EBuild _ _ a b -> re a <> many (re1 b)
EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
@@ -128,9 +132,11 @@ occCountGeneral onehot unpush alter many = go WId
EShape _ e -> re e
EOp _ _ e -> re e
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
+ ERecompute _ e -> re e
EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a b e -> re a <> re b <> re e
- EZero _ _ -> mempty
+ EAccum _ _ _ a _ b e -> re a <> re b <> re e
+ EZero _ _ e -> re e
+ EDeepZero _ _ e -> re e
EPlus _ _ a b -> re a <> re b
EOneHot _ _ _ a b -> re a <> re b
EError{} -> mempty
@@ -149,7 +155,7 @@ deleteUnused (_ `SCons` env) OccEnd k =
deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
deleteUnused env occenv $ \sub ->
case count of Zero -> k (SENo sub)
- _ -> k (SEYes sub)
+ _ -> k (SEYesR sub)
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
@@ -158,7 +164,7 @@ unsafeWeakenWithSubenv = \sub ->
Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
where
sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYes _) = Just IZ
+ sinkViaSubenv IZ (SEYesR _) = Just IZ
sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
index 4f34166..422f0f7 100644
--- a/src/AST/Env.hs
+++ b/src/AST/Env.hs
@@ -1,59 +1,85 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Env where
+import Data.Type.Equality
+
+import AST.Sparse
import AST.Weaken
+import CHAD.Types
import Data
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
+data Subenv' s env env' where
+ SETop :: Subenv' s '[] '[]
+ SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
+ SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
+deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
+
+type Subenv = Subenv' (:~:)
+type SubenvS = Subenv' Sparse
+
+pattern SEYesR :: forall tenv tenv'. ()
+ => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
+ => Subenv env env' -> Subenv tenv tenv'
+pattern SEYesR s = SEYes Refl s
-subList :: SList f env -> Subenv env env' -> SList f env'
+{-# COMPLETE SETop, SEYesR, SENo #-}
+
+subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
+subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
-subenvAll :: SList f env -> Subenv env env
+subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env
subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
+subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env)
-subenvNone :: SList f env -> Subenv env '[]
+subenvNone :: SList f env -> Subenv' s env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
-subenvOnehot SNil i = case i of {}
+subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t']
+subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
+subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
+subenvOnehot SNil i _ = case i of {}
-subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3
+subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
-subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
-subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1')
+subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
subenvConcat sub1 SETop = sub1
-subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2)
+subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
-sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0
+sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
sinkWithSubenv SETop = WId
-sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
sinkWithSubenv (SENo sub) = sinkWithSubenv sub
-wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
wUndoSubenv SETop = WId
-wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
+
+subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env'
+subenvMap _ SNil SETop = SETop
+subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub)
+subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub)
+
+subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env')
+subenvD2E SETop = SETop
+subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub)
+subenvD2E (SENo sub) = SENo (subenvD2E sub)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 604133b..fef9686 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,11 +7,12 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where
+module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
import Control.Monad (ap)
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Functor.Const
+import qualified Data.Functor.Product as Product
import Data.String (fromString)
import Prettyprinter
import Prettyprinter.Render.String
@@ -24,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -49,12 +51,20 @@ instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j)
genId :: M Int
genId = M (\i -> (i, i + 1))
+nameBaseForType :: STy t -> String
+nameBaseForType STNil = "nil"
+nameBaseForType (STPair{}) = "p"
+nameBaseForType (STEither{}) = "e"
+nameBaseForType (STMaybe{}) = "m"
+nameBaseForType (STScal STI32) = "n"
+nameBaseForType (STScal STI64) = "n"
+nameBaseForType (STArr{}) = "a"
+nameBaseForType (STAccum{}) = "ac"
+nameBaseForType _ = "x"
+
genName' :: String -> M String
genName' prefix = (prefix ++) . show <$> genId
-genName :: M String
-genName = genName' "x"
-
genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn' prefix ty idx ex
| occCount idx ex == mempty = case ty of STNil -> return "()"
@@ -62,19 +72,27 @@ genNameIfUsedIn' prefix ty idx ex
| otherwise = genName' prefix
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
-genNameIfUsedIn = genNameIfUsedIn' "x"
+genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
pprintExpr = putStrLn . ppExpr knownEnv
-ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
-ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
+ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr senv e = render $ fst . flip runM 1 $ do
+ val <- mkVal senv
+ e' <- ppExpr' 0 val e
+ let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
+ return $ group $ flatAlt
+ (hang 2 $
+ ppString lam
+ <> hardline <> e')
+ (ppString lam <+> e')
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
mkVal (SCons _ v) = do
val <- mkVal v
- name <- genName
+ name <- genName' "arg"
return (Const name `SCons` val)
ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
@@ -128,12 +146,45 @@ ppExpr' d val expr = case expr of
EMaybe _ a b e -> do
let STMaybe t = typeOf e
- a' <- ppExpr' 11 val a
+ e' <- ppExpr' 0 val e
+ a' <- ppExpr' 0 val a
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 0 (Const name `SCons` val) b
+ return $ ppParen (d > 0) $
+ align $
+ group (flatAlt
+ (annotate AKey (ppString "case") <> ppX expr <+> e'
+ <> hardline <> annotate AKey (ppString "of"))
+ (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
+ <> hardline
+ <> indent 2
+ (ppString "Nothing" <+> ppString "->" <+> a'
+ <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
+
+ ELNil _ _ _ -> return (ppString "LNil")
+
+ ELInl _ _ e -> do
e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $
- ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
+
+ ELInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
+
+ ELCase _ e a b c -> do
+ e' <- ppExpr' 0 val e
+ let STLEither t1 t2 = typeOf e
+ a' <- ppExpr' 11 val a
+ name1 <- genNameIfUsedIn t1 IZ b
+ b' <- ppExpr' 0 (Const name1 `SCons` val) b
+ name2 <- genNameIfUsedIn t2 IZ c
+ c' <- ppExpr' 0 (Const name2 `SCons` val) c
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
+ <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
+ <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
EConstArr _ _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -142,13 +193,14 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
+ let primName = ppString ("build" ++ intSubscript (fromSNat n))
return $ ppParen (d > 0) $
group $ flatAlt
(hang 2 $
- annotate AHighlight (ppString "build") <> ppX expr <+> a'
+ annotate AHighlight primName <> ppX expr <+> a'
<+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
<> hardline <> e')
- (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
+ (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -237,6 +289,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
@@ -249,27 +305,35 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ _ prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
- EZero _ t -> return $ ppParen (d > 0) $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
+ EZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
- EPlus _ _ a b -> do
+ EPlus _ t a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
+ ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b']
- EOneHot _ _ prj a b -> do
+ EOneHot _ t prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -302,14 +366,24 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
-ppAcPrj :: SAcPrj p a b -> String
-ppAcPrj SAPHere = "@"
-ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
-ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
-ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
-ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+ppAcPrj :: SMTy a -> SAcPrj p a b -> String
+ppAcPrj _ SAPHere = "."
+ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
+ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
+ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
+ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -334,30 +408,42 @@ operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
+operator OMod{} = (Infix, "`mod`")
ppSTy :: Int -> STy t -> String
-ppSTy d ty = ppTy d (unSTy ty)
+ppSTy d ty = render $ ppSTy' d ty
ppSTy' :: Int -> STy t -> Doc q
-ppSTy' d ty = ppTy' d (unSTy ty)
-
-ppTy :: Int -> Ty -> String
-ppTy d ty = render $ ppTy' d ty
-
-ppTy' :: Int -> Ty -> Doc q
-ppTy' _ TNil = ppString "1"
-ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
-ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
-ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
-ppTy' d (TArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t
-ppTy' _ (TScal sty) = ppString $ case sty of
- TI32 -> "i32"
- TI64 -> "i64"
- TF32 -> "f32"
- TF64 -> "f64"
- TBool -> "bool"
-ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+ppSTy' _ STNil = ppString "1"
+ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
+ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
+ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
+ppSTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
+ppSTy' _ (STScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+ STBool -> "bool"
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+
+ppSMTy :: Int -> SMTy t -> String
+ppSMTy d ty = render $ ppSMTy' d ty
+
+ppSMTy' :: Int -> SMTy t -> Doc q
+ppSMTy' _ SMTNil = ppString "1"
+ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
+ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
+ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
+ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
+ppSMTy' _ (SMTScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
new file mode 100644
index 0000000..93258b7
--- /dev/null
+++ b/src/AST/Sparse.hs
@@ -0,0 +1,290 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
+module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
+
+import Data.Type.Equality
+
+import AST
+import AST.Sparse.Types
+import Data (SBool(..))
+
+
+sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
+sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
+sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
+sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
+sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
+ eunPair e1 $ \w1 e1a e1b ->
+ eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
+ EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
+ (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
+sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
+ elet e2 $
+ elcase (weakenExpr WSink e1)
+ (evar IZ)
+ (elcase (evar (IS IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
+ (elcase (evar (IS IZ))
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
+ elet e2 $
+ emaybe (weakenExpr WSink e1)
+ (evar IZ)
+ (emaybe (evar (IS IZ))
+ (EJust ext (evar IZ))
+ (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
+sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+
+
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
+data Injection sp a b where
+ -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
+ -- 'sparsePlusS' can provide injections even if the caller doesn't require
+ -- them. This simplifies the sparsePlusS code.
+ Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
+ Noinj :: Injection False a b
+
+withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
+withInj (Inj f) k = Inj (k f)
+withInj Noinj _ = Noinj
+
+withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
+ -> ((forall e. Ex e a1 -> Ex e b1)
+ -> (forall e. Ex e a2 -> Ex e b2)
+ -> (forall e'. Ex e' a' -> Ex e' b'))
+ -> Injection sp a' b'
+withInj2 (Inj f) (Inj g) k = Inj (k f g)
+withInj2 Noinj _ _ = Noinj
+withInj2 _ Noinj _ = Noinj
+
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
+-- | This function produces quadratically-sized code in the presence of nested
+-- dynamic sparsity. TODO can this be improved?
+sparsePlusS
+ :: SBool inj1 -> SBool inj2
+ -> SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
+ -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
+
+-- simplifications
+sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
+ sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
+sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
+ sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
+
+sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
+ let ta = applySparse sp1 (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
+ let tb = applySparse sp2 (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
+ let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
+ let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
+ let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
+ let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
+sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
+
+-- TODO: sparse of Just is just Maybe
+
+-- dense plus
+sparsePlusS _ _ t sp1 sp2 k
+ | Just Refl <- isDense t sp1
+ , Just Refl <- isDense t sp2
+ = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+
+-- handle absents
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+
+-- double sparse yields sparse
+sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- single sparse can yield non-sparse if the other argument is always present
+sparsePlusS SF _ t (SpSparse sp1) sp2 k =
+ sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
+ k sp3 Noinj (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
+ sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
+ k sp3 inj2 inj1 (flip plus)
+
+-- products
+sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
+ sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
+ sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
+ k (SpPair sp3a sp3b)
+ (withInj2 minj13a minj13b $ \inj13a inj13b ->
+ \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
+ (withInj2 minj23a minj23b $ \inj23a inj23b ->
+ \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
+ (\x1 x2 ->
+ eunPair x1 $ \w1 x1a x1b ->
+ eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
+ EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
+
+-- coproducts
+sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
+ sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
+ let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
+ inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
+ inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
+ in
+ k (SpLEither sp3a sp3b)
+ (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
+ (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x2 $
+ elcase (weakenExpr WSink x1)
+ (elcase (evar IZ)
+ nil
+ (inl (inj23a (evar IZ)))
+ (inr (inj23b (evar IZ))))
+ (elcase (evar (IS IZ))
+ (inl (inj13a (evar IZ)))
+ (inl (plusa (evar (IS IZ)) (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
+ (elcase (evar (IS IZ))
+ (inr (inj13b (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
+ (inr (plusb (evar (IS IZ)) (evar IZ)))))
+
+-- maybe
+sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- dense array cotangents simply recurse
+sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
+ k (SpArr sp3)
+ (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
+ (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+ (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
+ (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+
+-- scalars
+sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs
new file mode 100644
index 0000000..10cac4e
--- /dev/null
+++ b/src/AST/Sparse/Types.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.Sparse.Types where
+
+import AST.Types
+
+import Data.Kind (Type, Constraint)
+import Data.Type.Equality
+
+
+data Sparse t t' where
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
+deriving instance Show (Sparse t t')
+
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
+
+
+class IsSubType s where
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull _ = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ SMTy
+ subtApply = applySparse
+
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
new file mode 100644
index 0000000..dcaf82f
--- /dev/null
+++ b/src/AST/SplitLets.hs
@@ -0,0 +1,154 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module AST.SplitLets (splitLets) where
+
+import Data.Type.Equality
+
+import AST
+import AST.Bindings
+import Lemmas
+
+
+splitLets :: Ex env t -> Ex env t
+splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
+
+splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
+splitLets' = \sub -> \case
+ EVar _ t i -> sub t i WId
+ ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ECase x e a b ->
+ let STEither t1 t2 = typeOf e
+ in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
+ EMaybe x a b e ->
+ let STMaybe t1 = typeOf e
+ in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
+ ELCase x e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
+ EFold1Inner x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+
+ EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
+ EFst x e -> EFst x (splitLets' sub e)
+ ESnd x e -> ESnd x (splitLets' sub e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (splitLets' sub e)
+ EInr x t e -> EInr x t (splitLets' sub e)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (splitLets' sub e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (splitLets' sub e)
+ ELInr x t e -> ELInr x t (splitLets' sub e)
+ EConstArr x n t a -> EConstArr x n t a
+ EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
+ EUnit x e -> EUnit x (splitLets' sub e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
+ EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
+ EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (splitLets' sub e)
+ EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
+ EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es)
+ EShape x e -> EShape x (splitLets' sub e)
+ EOp x op e -> EOp x op (splitLets' sub e)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
+ EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
+ EZero x t ezi -> EZero x t (splitLets' sub ezi)
+ EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
+ EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
+ EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
+ EError x t s -> EError x t s
+ where
+ sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
+ sinkF _ t IZ w = EVar ext t (w @> IZ)
+ sinkF f t (IS i) w = f t i (w .> WSink)
+
+ split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
+ split1 sub (tbind :: STy bind) body =
+ let (ptrs, bs) = split tbind
+ in letBinds bs $
+ splitLets' (\cases _ IZ w -> subPointers ptrs w
+ t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w)))
+ body
+
+ split2 :: forall bind1 bind2 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
+ split2 sub tbind1 tbind2 body =
+ let (ptrs1', bs1') = split @env' tbind1
+ bs1 = fst (weakenBindings weakenExpr WSink bs1')
+ (ptrs2, bs2) = split @(bind1 : env') tbind2
+ in letBinds bs1 $
+ letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
+ _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env')))
+ t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
+ body
+
+type family Split t where
+ Split (TPair a b) = SplitRec (TPair a b)
+ Split _ = '[]
+
+type family SplitRec t where
+ SplitRec TNil = '[]
+ SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a)
+ SplitRec t = '[t]
+
+data Pointers env t where
+ Point :: STy t -> Idx env t -> Pointers env t
+ PNil :: Pointers env TNil
+ PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
+ PWeak :: env' :> env -> Pointers env' t -> Pointers env t
+
+subPointers :: Pointers env t -> env :> env' -> Ex env' t
+subPointers (Point t i) w = EVar ext t (w @> i)
+subPointers PNil _ = ENil ext
+subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w)
+subPointers (PWeak w' p) w = subPointers p (w .> w')
+
+split :: forall env t. STy t
+ -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t))
+split typ = case typ of
+ STPair{} -> splitRec (EVar ext typ IZ) typ
+ STNil -> other
+ STEither{} -> other
+ STLEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ where
+ other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
+ other = (Point typ IZ, BTop)
+
+splitRec :: forall env t. Ex env t -> STy t
+ -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t))
+splitRec rhs typ = case typ of
+ STNil -> (PNil, BTop)
+ STPair (a :: STy a) (b :: STy b)
+ | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->
+ let (p1, bs1) = splitRec (EFst ext rhs) a
+ (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
+ in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
+ STEither{} -> other
+ STLEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ where
+ other :: (Pointers (t : env) t, Bindings Ex env '[t])
+ other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 217b2f5..42bfb92 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -1,64 +1,110 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module AST.Types where
import Data.Int (Int32, Int64)
+import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
-import Data.Some
import Data.Type.Equality
import Data
-data Ty
+type data Ty
= TNil
| TPair Ty Ty
| TEither Ty Ty
+ | TLEither Ty Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Ty -- ^ the accumulator contains D2 of this type
- deriving (Show, Eq, Ord)
+ | TAccum Ty -- ^ contained type must be a monoid type
-data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
- deriving (Show, Eq, Ord)
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
type STy :: Ty -> Type
data STy t where
STNil :: STy TNil
STPair :: STy a -> STy b -> STy (TPair a b)
STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
STMaybe :: STy a -> STy (TMaybe a)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
- STAccum :: STy t -> STy (TAccum t)
+ STAccum :: SMTy t -> STy (TAccum t)
deriving instance Show (STy t)
-instance TestEquality STy where
- testEquality STNil STNil = Just Refl
- testEquality STNil _ = Nothing
- testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STPair{} _ = Nothing
- testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STEither{} _ = Nothing
- testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STMaybe{} _ = Nothing
- testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STArr{} _ = Nothing
- testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STScal{} _ = Nothing
- testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STAccum{} _ = Nothing
-
+instance GCompare STy where
+ gcompare = \cases
+ STNil STNil -> GEQ
+ STNil _ -> GLT ; _ STNil -> GGT
+ (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STPair{} _ -> GLT ; _ STPair{} -> GGT
+ (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
+ STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
+ (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ STArr{} _ -> GLT ; _ STArr{} -> GGT
+ (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
+ STScal{} _ -> GLT ; _ STScal{} -> GGT
+ (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+
+instance TestEquality STy where testEquality = geq
+instance GEq STy where geq = defaultGeq
instance GShow STy where gshowsPrec = defaultGshowsPrec
+-- | Monoid types
+type SMTy :: Ty -> Type
+data SMTy t where
+ SMTNil :: SMTy TNil
+ SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
+ SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
+ SMTMaybe :: SMTy a -> SMTy (TMaybe a)
+ SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
+ SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+deriving instance Show (SMTy t)
+
+instance GCompare SMTy where
+ gcompare = \cases
+ SMTNil SMTNil -> GEQ
+ SMTNil _ -> GLT ; _ SMTNil -> GGT
+ (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT
+ (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT
+ (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
+ SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT
+ (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
+ (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
+ -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+
+instance TestEquality SMTy where testEquality = geq
+instance GEq SMTy where geq = defaultGeq
+instance GShow SMTy where gshowsPrec = defaultGshowsPrec
+
+fromSMTy :: SMTy t -> STy t
+fromSMTy = \case
+ SMTNil -> STNil
+ SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
+ SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
+ SMTMaybe t -> STMaybe (fromSMTy t)
+ SMTArr n t -> STArr n (fromSMTy t)
+ SMTScal sty -> STScal sty
+
data SScalTy t where
STI32 :: SScalTy TI32
STI64 :: SScalTy TI64
@@ -67,14 +113,21 @@ data SScalTy t where
STBool :: SScalTy TBool
deriving instance Show (SScalTy t)
-instance TestEquality SScalTy where
- testEquality STI32 STI32 = Just Refl
- testEquality STI64 STI64 = Just Refl
- testEquality STF32 STF32 = Just Refl
- testEquality STF64 STF64 = Just Refl
- testEquality STBool STBool = Just Refl
- testEquality _ _ = Nothing
-
+instance GCompare SScalTy where
+ gcompare = \cases
+ STI32 STI32 -> GEQ
+ STI32 _ -> GLT ; _ STI32 -> GGT
+ STI64 STI64 -> GEQ
+ STI64 _ -> GLT ; _ STI64 -> GGT
+ STF32 STF32 -> GEQ
+ STF32 _ -> GLT ; _ STF32 -> GGT
+ STF64 STF64 -> GEQ
+ STF64 _ -> GLT ; _ STF64 -> GGT
+ STBool STBool -> GEQ
+ -- STBool _ -> GLT ; _ STBool -> GGT
+
+instance TestEquality SScalTy where testEquality = geq
+instance GEq SScalTy where geq = defaultGeq
instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
@@ -89,50 +142,6 @@ type TIx = TScal TI64
tIx :: STy TIx
tIx = STScal STI64
-unSTy :: STy t -> Ty
-unSTy = \case
- STNil -> TNil
- STPair a b -> TPair (unSTy a) (unSTy b)
- STEither a b -> TEither (unSTy a) (unSTy b)
- STMaybe t -> TMaybe (unSTy t)
- STArr n t -> TArr (unSNat n) (unSTy t)
- STScal t -> TScal (unSScalTy t)
- STAccum t -> TAccum (unSTy t)
-
-unSEnv :: SList STy env -> [Ty]
-unSEnv SNil = []
-unSEnv (SCons t l) = unSTy t : unSEnv l
-
-unSScalTy :: SScalTy t -> ScalTy
-unSScalTy = \case
- STI32 -> TI32
- STI64 -> TI64
- STF32 -> TF32
- STF64 -> TF64
- STBool -> TBool
-
-reSTy :: Ty -> Some STy
-reSTy = \case
- TNil -> Some STNil
- TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b'
- TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b'
- TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t'
- TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t'
- TScal t | Some t' <- reSScalTy t -> Some $ STScal t'
- TAccum t | Some t' <- reSTy t -> Some $ STAccum t'
-
-reSEnv :: [Ty] -> Some (SList STy)
-reSEnv [] = Some SNil
-reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env)
-
-reSScalTy :: ScalTy -> Some SScalTy
-reSScalTy = \case
- TI32 -> Some STI32
- TI64 -> Some STI64
- TF32 -> Some STF32
- TF64 -> Some STF64
- TBool -> Some STBool
-
type family ScalRep t where
ScalRep TI32 = Int32
ScalRep TI64 = Int64
@@ -161,11 +170,12 @@ type family ScalIsIntegral t where
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
--- | Returns true for arrays /and/ accumulators;
+-- | Returns true for arrays /and/ accumulators.
hasArrays :: STy t' -> Bool
hasArrays STNil = False
hasArrays (STPair a b) = hasArrays a || hasArrays b
hasArrays (STEither a b) = hasArrays a || hasArrays b
+hasArrays (STLEither a b) = hasArrays a || hasArrays b
hasArrays (STMaybe t) = hasArrays t
hasArrays STArr{} = True
hasArrays STScal{} = False
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 0da1afc..48dd709 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -1,17 +1,22 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
+module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
-import CHAD.Types
+import AST.Sparse.Types
import Data
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
- EZero _ t -> zero t
+ EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -27,6 +32,10 @@ unMonoid = \case
ENothing _ t -> ENothing ext t
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unMonoid e)
+ ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
@@ -42,96 +51,200 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
-zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
-zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
-zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t))
-zero (STArr n t) = ENothing ext (STArr n (d2 t))
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+-- don't destroy the effects!
+zero SMTNil e = ELet ext e $ ENil ext
+zero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
+zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
+zero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil e = elet e $ ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
+plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
+-- don't destroy the effects!
+plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
+plus (SMTPair t1 t2) a b =
+ let t = STPair (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
(EFst ext (EVar ext t IZ)))
(plus t2 (ESnd ext (EVar ext t (IS IZ)))
(ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+plus (SMTLEither t1 t2) a b =
+ let t = STLEither (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t IZ)
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
(EError ext t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
(EError ext t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus (STMaybe t) a b =
- plusSparse (d2 t) a b $
- plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-plus (STArr n t) a b =
- plusSparse (STArr n (d2 t)) a b $
- eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ)
- (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (EVar ext (STArr n (d2 t)) IZ)))
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
-plusSparse :: STy a
- -> Ex env (TMaybe a) -> Ex env (TMaybe a)
- -> Ex (a : a : env) a
- -> Ex env (TMaybe a)
-plusSparse t a b adder =
+ (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
+plus (SMTMaybe t) a b =
ELet ext b $
EMaybe ext
- (EVar ext (STMaybe t) IZ)
+ (EVar ext (STMaybe (fromSMTy t)) IZ)
(EJust ext
(EMaybe ext
- (EVar ext t IZ)
- (weakenExpr (WCopy (WCopy WSink)) adder)
- (EVar ext (STMaybe t) (IS IZ))))
+ (EVar ext (fromSMTy t) IZ)
+ (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
(weakenExpr WSink a)
+plus (SMTArr _ t) a b =
+ ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ a b
+plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) -> arg
+ (_, SAPHere) ->
+ ELet ext arg $
+ EVar ext (fromSMTy typ) IZ
- (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
- (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
+ (SMTPair t1 t2, SAPFst prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t1 in
+ EPair ext (EVar ext toh IZ)
+ (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
+
+ (SMTPair t1 t2, SAPSnd prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t2 in
+ EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
+ (EVar ext toh IZ)
- (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
- (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
+ (SMTLEither t1 t2, SAPLeft prj) ->
+ ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ (SMTLEither t1 t2, SAPRight prj) ->
+ ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
- (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
+ (SMTMaybe t1, SAPJust prj) ->
+ EJust ext (onehot t1 prj idx arg)
- (STArr n t1, SAPArrIdx prj _) ->
+ (SMTArr n t1, SAPArrIdx prj) ->
let tidx = tTup (sreplicate n tIx)
in ELet ext idx $
- EJust ext $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
+ zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index dbb37f7..d882e28 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -36,6 +36,15 @@ splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
+slistIdx :: SList f list -> Idx list t -> f t
+slistIdx (SCons x _) IZ = x
+slistIdx (SCons _ list) (IS i) = slistIdx list i
+slistIdx SNil i = case i of {}
+
+idx2int :: Idx env t -> Int
+idx2int IZ = 0
+idx2int (IS n) = 1 + idx2int n
+
data env :> env' where
WId :: env :> env
WSink :: forall t env. env :> (t : env)
@@ -117,3 +126,7 @@ wCopies bs w =
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
+
+wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2
+wPops SNil w = w
+wPops (_ `SCons` bs) w = wPops bs (WPop w)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 6752c24..c6efe37 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where
SSegNil :: SSegments '[]
SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where
+instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where
fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
auto :: KnownListSpine list => SList (Const ()) list
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 095d0fa..b54946b 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -2,8 +2,12 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
module Analysis.Identity (
identityAnalysis,
+ identityAnalysis',
+ ValId(..),
+ validSplitEither,
) where
import Data.Foldable (toList)
@@ -24,11 +28,13 @@ data ValId t where
VIPair :: ValId a -> ValId b -> ValId (TPair a b)
VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
+deriving instance Show (ValId t)
instance PrettyX ValId where
prettyX = \case
@@ -40,16 +46,31 @@ instance PrettyX ValId where
VIMaybe Nothing -> "N"
VIMaybe (Just a) -> 'J' : prettyX a
VIMaybe' a -> 'M' : prettyX a
+ VILEither (VIMaybe Nothing) -> "lN"
+ VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")"
+ VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"
VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
VIScal i -> show i
VIAccum i -> 'C' : show i
+validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b))
+validSplitEither (VIEither (Left v)) = (Just v, Nothing)
+validSplitEither (VIEither (Right v)) = (Nothing, Just v)
+validSplitEither (VIEither' v1 v2) = (Just v1, Just v2)
+
-- | Symbolic partial evaluation.
identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
identityAnalysis env term = runIdGen 0 $ do
env' <- slistMapA genIds env
snd <$> idana env' term
+identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' env term = snd (runIdGen 0 (idana env term))
+
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
idana env expr = case expr of
EVar _ t i -> do
@@ -103,9 +124,9 @@ idana env expr = case expr of
(v3, e3') <- idana (v1' `SCons` env) e3
pure (v3, ECase v3 e1' e2' e3')
VIEither' v1'l v1'r -> do
- (_, e2') <- idana (v1'l `SCons` env) e2
- (_, e3') <- idana (v1'r `SCons` env) e3
- res <- genIds (typeOf expr)
+ (v2, e2') <- idana (v1'l `SCons` env) e2
+ (v3, e3') <- idana (v1'r `SCons` env) e3
+ res <- unify v2 v3
pure (res, ECase res e1' e2' e3')
ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)
@@ -134,6 +155,42 @@ idana env expr = case expr of
res <- unify v1 v2
pure (res, EMaybe res e1' e2' e3')
+ ELNil _ t1 t2 -> do
+ let v = VILEither (VIMaybe Nothing)
+ pure (v, ELNil v t1 t2)
+
+ ELInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VILEither (VIMaybe (Just (VIEither (Left v1))))
+ pure (v, ELInl v t2 e1')
+
+ ELInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VILEither (VIMaybe (Just (VIEither (Right v2))))
+ pure (v, ELInr v t1 e2')
+
+ ELCase _ e1 e2 e3 e4 -> do
+ let STLEither t1 t2 = typeOf e1
+ (v1L, e1') <- idana env e1
+ let VILEither v1 = v1L
+ let go mv1'l mv1'r f = do
+ v1'l <- maybe (genIds t1) pure mv1'l
+ v1'r <- maybe (genIds t2) pure mv1'r
+ (v2, e2') <- idana env e2
+ (v3, e3') <- idana (v1'l `SCons` env) e3
+ (v4, e4') <- idana (v1'r `SCons` env) e4
+ res <- f v2 v3 v4
+ pure (res, ELCase res e1' e2' e3' e4')
+ case v1 of
+ VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2)
+ VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3)
+ VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4)
+ VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4)
+ VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3)
+ VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4)
+ VIMaybe' (VIEither' v1'l v1'r) ->
+ go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4)
+
EConstArr _ dim t arr -> do
x1 <- VIArr <$> genId <*> vecReplicateA dim genId
pure (x1, EConstArr x1 dim t arr)
@@ -237,6 +294,10 @@ idana env expr = case expr of
res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
EWith _ t e1 e2 -> do
let t1 = typeOf e1
(_, e1') <- idana env e1
@@ -246,26 +307,36 @@ idana env expr = case expr of
let res = VIPair v2 x2
pure (res, EWith res t e1' e2')
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- pure (VINil, EAccum VINil t prj e1' e2' e3')
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
- EZero _ t -> do
- res <- genIds (d2 t)
- pure (res, EZero res t)
+ EZero _ t e1 -> do
+ -- Approximate the result of EZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EZero res t e1')
+
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EPlus res t e1' e2')
EOneHot _ t i e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EOneHot res t i e1' e2')
EError _ t s -> do
@@ -294,6 +365,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VILEither a) (VILEither b) = VILEither <$> unify a b
unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
@@ -306,6 +378,7 @@ genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
diff --git a/src/CHAD.hs b/src/CHAD.hs
index be308cd..143376a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -32,19 +33,23 @@ module CHAD (
) where
import Data.Functor.Const
-import Data.Type.Bool (If)
-import Data.Type.Equality (type (==))
+import Data.Some
+import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
+import Analysis.Identity (ValId(..), validSplitEither)
import AST
import AST.Bindings
import AST.Count
import AST.Env
+import AST.Sparse
import AST.Weaken.Auto
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
+import Data.VarMap (VarMap)
import Lemmas
@@ -58,14 +63,20 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds
- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollect BTop SETop _ = ENil ext
-bindingsCollect (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
- (bindingsCollect binds sub (w .> WSink))
-bindingsCollect (BPush binds _) (SENo sub) w =
- bindingsCollect binds sub (w .> WSink)
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
+
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
@@ -215,6 +226,7 @@ d1op (ORecip t) e = EOp ext (ORecip t) e
d1op (OExp t) e = EOp ext (OExp t) e
d1op (OLog t) e = EOp ext (OLog t) e
d1op (OIDiv t) e = EOp ext (OIDiv t) e
+d1op (OMod t) e = EOp ext (OMod t) e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -222,25 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
-> D2Op (TScal a) t
@@ -255,11 +279,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -286,86 +310,182 @@ conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
- = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t))
- | Idx2Me (Idx (Select env sto "merge") t)
+ = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
-conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ
-conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ
-conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ
-conv2Idx (DPush des (_, SAccum)) (IS i) =
+conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ
+conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ
+conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ
+conv2Idx (DPush des (_, _, SAccum)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
Idx2Me j -> Idx2Me j
Idx2Di j -> Idx2Di j
-conv2Idx (DPush des (_, SMerge)) (IS i) =
+conv2Idx (DPush des (_, _, SMerge)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac j
Idx2Me j -> Idx2Me (IS j)
Idx2Di j -> Idx2Di j
-conv2Idx (DPush des (_, SDiscr)) (IS i) =
+conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac j
Idx2Me j -> Idx2Me j
Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}
-
------------------------------------- MONOIDS -----------------------------------
-
-zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
-zeroTup SNil = ENil ext
-zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t)
-
-
------------------------------------- SUBENVS -----------------------------------
-
-subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (EPlus ext t
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
-
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t)
-
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
+
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
+
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
+
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
@@ -373,6 +493,10 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
+fromArrayValId :: Maybe (ValId t) -> Maybe Int
+fromArrayValId (Just (VIArr i _)) = Just i
+fromArrayValId _ = Nothing
+
accumPromote :: forall dt env sto proxy r.
proxy dt
-> Descr env sto
@@ -381,8 +505,7 @@ accumPromote :: forall dt env sto proxy r.
=> Descr env stoRepl
-- ^ A revised environment description that switches
-- arrays (used in the OccEnv) that are currently on
- -- "merge" storage, to "accum" storage. Any other "merge"
- -- entries are deleted.
+ -- "merge" storage, to "accum" storage.
-> SList STy envPro
-- ^ New entries on top of the original dual environment,
-- that house the accumulators for the promoted arrays in
@@ -390,72 +513,92 @@ accumPromote :: forall dt env sto proxy r.
-> Subenv (Select env sto "merge") envPro
-- ^ The promoted entries were merge entries in the
-- original environment.
+ -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
+ -- ^ All entries that were accumulators are still
+ -- accumulators.
+ -> VarMap Int (D2AcE (Select env stoRepl "accum"))
+ -- ^ Accumulator map for _only_ the the newly allocated
+ -- accumulators.
-> (forall shbinds.
SList STy shbinds
- -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
-> r)
-> r
-accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId)
-accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
- accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf ->
- case sto of
- -- Accumulators are left as-is
- SAccum ->
- k (storepl `DPush` (t, SAccum))
- envpro
- prosub
- (\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr)))
- (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
- (#pro :++: #d :++: #shb :++: #acc :++: #tl)
- .> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl)))
- (#d :++: #shb :++: #acc :++: #tl)
- (#acc :++: (#d :++: #shb :++: #tl)))
-
- SMerge -> case t of
- -- Discrete values are left as-is
- _ | isDiscrete t ->
- k (storepl `DPush` (t, SDiscr))
- envpro
- (SENo prosub)
- wf
-
- -- Values with "merge" storage are promoted to an accumulator in envPro
- _ ->
- k (storepl `DPush` (t, SAccum))
- (t `SCons` envpro)
- (SEYes prosub)
- (\(shbinds :: SList _ shbinds) ->
- let shbindsC = slistMap (\_ -> Const ()) shbinds
- in
- -- wf:
- -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WCopy wf:
- -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WPICK: ^ THESE TWO ||
- -- goal: | ARE EQUAL ||
- -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- WCopy (wf shbinds)
- .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
- (WId @(D2AcE (Select env1 stoRepl "accum"))))
-
- -- Discrete values are left as-is, nothing to do
- SDiscr ->
- k (storepl `DPush` (t, SDiscr))
+accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
+ -- Accumulators are left as-is
+ SAccum ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ envpro
+ prosub
+ (SEYesR accrevsub)
+ (VarMap.sink1 accumMap)
+ (\shbinds ->
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
+ (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ .> WCopy (wf shbinds)
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #tl)
+ (#acc :++: (#d :++: #shb :++: #tl)))
+
+ SMerge -> case t of
+ -- Discrete values are left as-is
+ _ | isDiscrete t ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
envpro
- prosub
+ (SENo prosub)
+ accrevsub
+ accumMap'
wf
+
+ -- Values with "merge" storage are promoted to an accumulator in envPro
+ _ ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ (t `SCons` envpro)
+ (SEYesR prosub)
+ (SENo accrevsub)
+ (let accumMap' = VarMap.sink1 accumMap
+ in case fromArrayValId vid of
+ Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap'
+ Nothing -> accumMap')
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- Discrete values are left as-is, nothing to do
+ SDiscr ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ prosub
+ accrevsub
+ accumMap
+ wf
where
isDiscrete :: STy t' -> Bool
isDiscrete = \case
STNil -> True
STPair a b -> isDiscrete a && isDiscrete b
STEither a b -> isDiscrete a && isDiscrete b
+ STLEither a b -> isDiscrete a && isDiscrete b
STMaybe a -> isDiscrete a
STArr _ a -> isDiscrete a
STScal st -> case st of
@@ -469,21 +612,41 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
-data Ret env0 sto t =
- forall shbinds tapebinds env0Merge.
+data Ret env0 sto sd t =
+ forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (Ret env0 sto t)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
-data RetPair env0 sto env shbinds tapebinds t =
- forall env0Merge.
- RetPair (Ex (Append shbinds env) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
+
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
data Rets env0 sto env list =
forall shbinds tapebinds.
@@ -492,8 +655,11 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
weakenRetPair :: SList STy shbinds -> env :> env'
- -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
@@ -501,104 +667,137 @@ weakenRets w (Rets binds tapesub list) =
let (binds', _) = weakenBindings weakenExpr w binds
in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
Descr env0 sto
-> SList f b1 -> SList f b2
-> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
- -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
- -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
-rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t))
- &. #t2 (subList b2 subtape2)
- &. #t1 (subList b1 subtape1)
- &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#t2 :++: #tl))
- (#d :++: ((#t2 :++: #t1) :++: #tl)))
- d)
-
-retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
+
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
| Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
- <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
, Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
, Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
- = Rets (bconcat b binds)
+ = Rets (bconcat e0 binds)
(subenvConcat subtape subtape2)
- (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
sub
- (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
subtape subtape2)
pairs))
freezeRet :: Descr env sto
- -> Ret env sto t
+ -> Ret env sto (D2 t) t
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tape (subList (bindingsBinds e0) subtape)
- &. #shbinds (bindingsBinds e0)
- &. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr))
+ (ELet ext (weakenExpr (autoWeak library
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
- expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
---------------------------- THE CHAD TRANSFORMATION ---------------------------
-drev :: forall env sto t.
+drev :: forall env sto sd t.
(?config :: CHADConfig)
- => Descr env sto
- -> Ex env t -> Ret env sto t
-drev des = \case
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
- (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI)))
+ (subenvNone (d2e (select SMerge des)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (select SMerge des) tupI)
- (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
Idx2Di _ ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
- ELet _ (rhs :: Ex _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
- , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body
+ ELet _ (rhs :: Expr _ _ a) body
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
- subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
- (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ (subenvConcat subtapeRHS subtapeBody)
(weakenExpr wbody0' body1)
subBoth
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
&. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
@@ -608,328 +807,374 @@ drev des = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EVar ext (contribTupTy des subRHS) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
- , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
- subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
subtape
(EPair ext a1 b1)
subBoth
- (EMaybe ext
- (zeroTup (subList (select SMerge des) subBoth))
- (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
- (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
Ret e0
subtape
(EFst ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
Ret e0
subtape
(ESnd ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
- (weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
- (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy (wSinks' @[_,_])) e2))
- (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
-
- ECase _ e (a :: Ex _ t) b
- | STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
+
+ ECase _ e (a :: Expr _ _ t) b
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b
+ , let (bindids1, bindids2) = validSplitEither (extOf e)
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
- , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
- , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollect a0 subtapeA
- , let collectB = bindingsCollect b0 subtapeB
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
->
- subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
- subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
- let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
Ret (e0 `BPush`
(tPrimal,
ECase ext e1
- (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
- (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
- (SEYes subtapeE)
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))))
+ (SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
- (ELet ext
+ (elet
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #ta0 (subList (bindingsBinds a0) subtapeA)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
&. #prea0 prerebinds
- &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #ta0 :++: #tl)
(#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
- (EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tb0 (subList (bindingsBinds b0) subtapeB)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
&. #preb0 prerebinds
- &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #tb0 :++: #tl)
(#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
- (EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
- ELet ext
- (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
EConst _ t val ->
Ret BTop
SETop
(EConst ext t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
case d2op op of
Linear d2opfun ->
Ret e0
subtape
(d1op op e1)
sub
- (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
Ret (e0 `BPush` (d1 (typeOf e), e1))
- (SEYes subtape)
+ (SEYesR subtape)
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
- ECustom _ _ _ storety _ pr du a b
+ ECustom _ _ tb storety srce pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
- | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil ->
- Ret (binds `BPush` (typeOf a1, a1)
- `BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr)
- `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYes (SENo (SENo (SENo subtape))))
- (EFst ext (EVar ext (typeOf pr) (IS IZ)))
- bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
- weakenExpr (WCopy (WSink .> WSink)) b2)
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
+ (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @sd)
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
EError _ t s ->
Ret BTop
SETop
(EError ext (d1 t) s)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EConstArr _ n t val ->
Ret BTop
SETop
(EConstArr ext n t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
+ EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
+ | SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
- subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
+ case lemAppendNil @e_binds of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollect e0 subtapeE in
- Ret (BTop `BPush` (shty, letBinds she0 she1)
- `BPush` (STArr ndim (STPair (d1 eltty) tapety)
- ,EBuild ext ndim
- (EVar ext shty IZ)
- (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#ix :++: #sh :++: #d1env))
- e0)) $
- let w = autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #e0 (bindingsBinds e0)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#e0 :++: #ix :++: #sh :++: #d1env)
- in EPair ext (weakenExpr w e1) (collectexpr w)))
- `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYes (SENo (SEYes SETop)))
- (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
- (subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- EMaybe ext
- (zeroTup envPro)
- (ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (auto1 @(Tape e_tape))
- &. #ix (auto1 @shty)
- &. #darr (auto1 @(TArr ndim (D2 eltty)))
- &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
- &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
- &. #sh (auto1 @shty)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
- (EVar ext (d2 (STArr ndim eltty)) IZ))
- }}
+ let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ Ret (mergePrimalBindings
+ `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she))
+ `BPush` (STArr ndim (STPair (d1 eltty) tapety)
+ ,EBuild ext ndim
+ (EVar ext shty IZ)
+ (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#ix :++: #sh :++: #propr :++: #d1env))
+ e0)) $
+ let w = autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env)
+ w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
+ in EPair ext (weakenExpr w e1) (collectexpr w')))
+ `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
+ (SEYesR (SENo (SEYesR (subenvAll (d1e envPro)))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in
+ ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @sdElt)
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2)
+ }}}
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
Ret e0
subtape
(EUnit ext e1)
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EReplicate1Inner _ en e
- -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
, let STArr ndim eltty = typeOf e ->
- Ret binds
- subtape
- (EReplicate1Inner ext en1 e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EFold1Inner ext Commut
- (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (EZero ext eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (sparsePlus (d2M eltty) sdElt'
+ (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
+ (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
, STArr _ t <- typeOf e ->
Ret e0
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
- weakenExpr (WCopy WSink) e2)
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
{-
EIdx1 _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
(weakenExpr (WSink .> WSink) ei1))
sub
@@ -940,55 +1185,58 @@ drev des = \case
-}
EIdx _ e ei
- -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
- , STArr n eltty <- typeOf e
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
- , let tIxN = tTup (sreplicate n tIx) ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
- `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYes (SEYes (SENo subtape)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- sub
- (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n)
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ))))
- (ENil ext))
- (EVar ext (d2 eltty) IZ)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `BPush` (STArr n (d1 eltty), e1)
+ `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
+ `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -996,10 +1244,14 @@ drev des = \case
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"
+ ELNil{} -> err_unsupported "ELNil"
+ ELInl{} -> err_unsupported "ELInl"
+ ELInr{} -> err_unsupported "ELInr"
+ ELCase{} -> err_unsupported "ELCase"
EWith{} -> err_accum
- EAccum{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1008,68 +1260,116 @@ drev des = \case
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYes (SEYes subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (EZero ext t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
-data RetScoped env0 sto a s t =
- forall shbinds tapebinds env0Merge.
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
- (Subenv shbinds tapebinds)
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a))))
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
-deriving instance Show (RetScoped env0 sto a s t)
+deriving instance Show (RetScoped env0 sto a s sd t)
-drevScoped :: forall a s env sto t.
+drevScoped :: forall a s env sto sd t.
(?config :: CHADConfig)
- => Descr env sto -> STy a -> Storage s
- -> Ex (a : env) t
- -> RetScoped env sto a s t
-drevScoped des argty argsto expr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
- = case argsto of
- SMerge ->
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
+ -> Expr ValId (a : env) t
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
+ SMerge
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
case sub of
- SEYes sub' -> RetScoped e0 subtape e1 sub' e2
- SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
- SAccum ->
- RetScoped e0 subtape e1 sub $
- EWith ext argty (EZero ext argty) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds e0) subtape)
- &. #ac (auto1 @(TAccum a))
- &. #tl (d2ace (select SAccum des)))
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
+
+ SAccum
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
+ , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
+ , Just Refl <- testEquality foundTy (STAccum (d2M argty))
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
+ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
+ ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
+ weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ (EPair ext e2 (ENil ext))
+
+ | let accumMap' = case argids of
+ Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
+ _ -> VarMap.sink1 accumMap
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
- (#ac :++: #d :++: #body :++: #tl))
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
e2
- SDiscr -> RetScoped e0 subtape e1 sub e2
+
+ SDiscr
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
+ = mapExt (const ext) e
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index b61b5ff..7212232 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -1,18 +1,54 @@
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+-- | TODO this module is a grab-bag of random utility functions that are shared
+-- between CHAD and CHAD.Top.
module CHAD.Accum where
import AST
import CHAD.Types
import Data
+import AST.Env
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e =
- makeAccumulators envpro $
- EWith ext t (EZero ext t) e
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
@@ -25,3 +61,7 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
index fcd91f7..49ae0e6 100644
--- a/src/CHAD/EnvDescr.hs
+++ b/src/CHAD/EnvDescr.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
@@ -9,10 +10,13 @@
module CHAD.EnvDescr where
import Data.Kind (Type)
+import Data.Some
import GHC.TypeLits (Symbol)
+import Analysis.Identity (ValId(..))
import AST.Env
import AST.Types
+import AST.Weaken
import CHAD.Types
import Data
@@ -27,12 +31,17 @@ deriving instance Show (Storage s)
-- | Environment description
data Descr env sto where
DTop :: Descr '[] '[]
- DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
+ DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto)
deriving instance Show (Descr env sto)
descrList :: Descr env sto -> SList STy env
descrList DTop = SNil
-descrList (des `DPush` (t, _)) = t `SCons` descrList des
+descrList (des `DPush` (t, _, _)) = t `SCons` descrList des
+
+descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage)
+descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto)
+descrPrj (des `DPush` _) (IS i) = descrPrj des i
+descrPrj DTop i = case i of {}
-- | This could have more precise typing on the output storage.
subDescr :: Descr env sto -> Subenv env env'
@@ -43,13 +52,13 @@ subDescr :: Descr env sto -> Subenv env env'
-> r)
-> r
subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
- SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
-subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e)
+ SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e)
+ SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)
+subDescr (des `DPush` (_, _, sto)) (SENo sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
@@ -64,12 +73,24 @@ type family Select env sto s where
select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
select _ DTop = SNil
-select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
-select s@SMerge (DPush des (_, SAccum)) = select s des
-select s@SDiscr (DPush des (_, SAccum)) = select s des
-select s@SAccum (DPush des (_, SMerge)) = select s des
-select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-select s@SDiscr (DPush des (_, SMerge)) = select s des
-select s@SAccum (DPush des (_, SDiscr)) = select s des
-select s@SMerge (DPush des (_, SDiscr)) = select s des
-select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
+select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, _, SAccum)) = select s des
+select s@SDiscr (DPush des (_, _, SAccum)) = select s des
+select s@SAccum (DPush des (_, _, SMerge)) = select s des
+select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des)
+select s@SDiscr (DPush des (_, _, SMerge)) = select s des
+select s@SAccum (DPush des (_, _, SDiscr)) = select s des
+select s@SMerge (DPush des (_, _, SDiscr)) = select s des
+select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des)
+
+selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s)
+selectSub _ DTop = SETop
+selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index d058132..484779e 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -10,13 +10,18 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
+import Analysis.Identity
import AST
+import AST.Env
+import AST.Sparse
+import AST.SplitLets
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
type family MergeEnv env where
@@ -25,7 +30,7 @@ type family MergeEnv env where
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr SNil = DTop
-mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge)
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge)
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum SNil = Refl
@@ -38,38 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (t `SCons` env) k = accumDescr env $ \des ->
- if hasArrays t then k (des `DPush` (t, SAccum))
- else k (des `DPush` (t, SMerge))
-
-d1Identity :: STy t -> D1 t :~: t
-d1Identity = \case
- STNil -> Refl
- STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STMaybe t | Refl <- d1Identity t -> Refl
- STArr _ t | Refl <- d1Identity t -> Refl
- STScal _ -> Refl
- STAccum{} -> error "Accumulators not allowed in input program"
-
-d1eIdentity :: SList STy env -> D1E env :~: env
-d1eIdentity SNil = Refl
-d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+ if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
reassembleD2E :: Descr env sto
+ -> D1E env :> env'
-> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
-> Ex env' (Tup (D2E env))
-reassembleD2E DTop _ = ENil ext
-reassembleD2E (des `DPush` (_, SAccum)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
- (ESnd ext (EVar ext (typeOf e) IZ))))
- (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (_, SMerge)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
- (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
- (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t)
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
@@ -79,21 +71,24 @@ chad config env (term :: Ex env t)
let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
tvar = STPair t1 (tTup (d2e (select SAccum descr)))
in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
- makeAccumulators (select SAccum descr) $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #acenv (d2ace (select SAccum descr))
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr term)) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
- (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
- (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
+ where
+ term' = identityAnalysis env (splitLets term)
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 7f49cef..44ac20e 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
+import AST.Accum
import AST.Types
import Data
@@ -11,16 +13,18 @@ type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
+ D2 (TEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TMaybe (TArr n (D2 t))
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -40,12 +44,13 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = TAccum t : D2AcE env
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
@@ -55,27 +60,34 @@ d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
d1e (t `SCons` env) = d1 t `SCons` d1e env
+d2M :: STy t -> SMTy (D2 t)
+d2M STNil = SMTNil
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
+d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STMaybe t) = SMTMaybe (d2M t)
+d2M (STArr n t) = SMTArr n (d2M t)
+d2M (STScal t) = case t of
+ STI32 -> SMTNil
+ STI64 -> SMTNil
+ STF32 -> SMTScal STF32
+ STF64 -> SMTScal STF64
+ STBool -> SMTNil
+d2M STAccum{} = error "Accumulators not allowed in input program"
+
d2 :: STy t -> STy (D2 t)
-d2 STNil = STNil
-d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
-d2 (STMaybe t) = STMaybe (d2 t)
-d2 (STArr n t) = STMaybe (STArr n (d2 t))
-d2 (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-d2 STAccum{} = error "Accumulators not allowed in input program"
+d2 = fromSMTy . d2M
+
+d2eM :: SList STy env -> SList SMTy (D2E env)
+d2eM SNil = SNil
+d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts
d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (t `SCons` ts) = d2 t `SCons` d2e ts
+d2e = slistMap fromSMTy . d2eM
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts
+d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts
data CHADConfig = CHADConfig
@@ -85,6 +97,8 @@ data CHADConfig = CHADConfig
chcCaseArrayAccum :: Bool
, -- | Introduce top-level arguments containing arrays in accumulator mode.
chcArgArrayAccum :: Bool
+ , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
+ chcSmartWith :: Bool
}
deriving (Show)
@@ -93,12 +107,14 @@ defaultConfig = CHADConfig
{ chcLetArrayAccum = False
, chcCaseArrayAccum = False
, chcArgArrayAccum = False
+ , chcSmartWith = False
}
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum c = c { chcLetArrayAccum = True
, chcCaseArrayAccum = True
- , chcArgArrayAccum = True }
+ , chcArgArrayAccum = True
+ , chcSmartWith = True }
------------------------------------ LEMMAS ------------------------------------
@@ -106,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
+
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index f843206..888fed4 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -19,24 +19,25 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
STEither t1 t2 -> case der of
Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
Just d -> case (primal, d) of
(Left p, Left d') -> Left (toTan t1 p d')
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"
+ STLEither t1 t2 -> case (primal, der) of
+ (_, Nothing) -> Nothing
+ (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
+ (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
+ _ -> error "Primal and cotangent disagree on LEither alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t -> case der of
- Nothing -> arrayMap (zeroTan t) primal
- Just d
- | arrayShape primal == arrayShape d ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Compile.hs b/src/Compile.hs
index b4261ca..a5c4fb7 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -6,8 +6,9 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-module Compile (compile, debugCSource, debugRefc, emitChecks) where
+module Compile (compile) where
import Control.Applicative (empty)
import Control.Monad (forM_, when, replicateM)
@@ -21,6 +22,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
import Data.Functor.Product (Product)
+import Data.IORef
import Data.List (foldl1', intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
@@ -34,6 +36,8 @@ import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding ((^))
import qualified Prelude
@@ -41,7 +45,7 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
-import qualified CHAD.Types as CHAD
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
@@ -69,22 +73,25 @@ emitChecks :: Bool; emitChecks = toEnum 0
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
- let source = compileToString env expr
+ codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
+
+ let (source, offsets) = compileToString codeID env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
- when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
- lib <- buildKernel source ["kernel"]
+ when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
+ lib <- buildKernel source "kernel"
- let arg_metrics = reverse (unSList metricsSTy env)
- (arg_offsets, result_offset) = computeStructOffsets arg_metrics
- result_type = typeOf expr
+ let result_type = typeOf expr
result_size = sizeofSTy result_type
return $ \val -> do
- allocaBytes (result_offset + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
serialiseArguments args ptr $ do
- callKernelFun "kernel" lib ptr
- deserialise result_type ptr result_offset
+ callKernelFun lib ptr
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
where
serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
@@ -215,6 +222,7 @@ genStructName = \t -> "ty_" ++ gen t where
gen STNil = "n"
gen (STPair a b) = 'P' : gen a ++ gen b
gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
gen (STMaybe t) = 'M' : gen t
gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
gen (STScal st) = case st of
@@ -223,11 +231,14 @@ genStructName = \t -> "ty_" ++ gen t where
STF32 -> "f"
STF64 -> "d"
STBool -> "b"
- gen (STAccum t) = 'C' : gen t
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
-- types in the C translation.
+--
+-- For accumulation it is important that for struct representations of monoid
+-- types, the all-zero-bytes value corresponds to the zero value of that type.
genStruct :: String -> STy t -> [StructDecl]
genStruct name topty = case topty of
STNil ->
@@ -236,16 +247,20 @@ genStruct name topty = case topty of
[StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
STEither a b -> -- 0 -> l, 1 -> r
[StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
STMaybe t -> -- 0 -> nothing, 1 -> just
[StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
+ -- TODO: put shape in the main struct, not the buffer; it's constant, after all
+ -- TODO: no buffer if n = 0
[StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
STScal _ ->
[]
STAccum t ->
- [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") ""
+ [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
,StructDecl name (name ++ "_buf *buf;") com]
where
com = ppSTy 0 topty
@@ -268,18 +283,19 @@ genStructs ty = do
STNil -> pure ()
STPair a b -> genStructs a >> genStructs b
STEither a b -> genStructs a >> genStructs b
+ STLEither a b -> genStructs a >> genStructs b
STMaybe t -> genStructs t
STArr _ t -> genStructs t
STScal _ -> pure ()
- STAccum t -> genStructs (CHAD.d2 t)
+ STAccum t -> genStructs (fromSMTy t)
tell (BList (genStruct name ty))
-genAllStructs :: Foldable t => t Ty -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty
+genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
+genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
data CompState = CompState
- { csStructs :: Set Ty
+ { csStructs :: Set (Some STy)
, csTopLevelDecls :: Bag String
, csStmts :: Bag Stmt
, csNextId :: Int }
@@ -322,7 +338,7 @@ scope m = do
emitStruct :: STy t -> CompM String
emitStruct ty = CompM $ do
- modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) }
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
return (genStructName ty)
emitTLD :: String -> CompM ()
@@ -331,63 +347,94 @@ emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <>
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-compileToString :: SList STy env -> Ex env t -> String
-compileToString env expr =
+data KernelOffsets = KernelOffsets
+ { koArgOffsets :: [Int] -- ^ the function arguments
+ , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
+ , koResultOffset :: Int -- ^ the function result
+ }
+
+compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
+compileToString codeID env expr =
let args = nameEnv env
(res, s) = runCompM (compile' args expr)
- structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
+ structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
(arg_pairs, arg_metrics) =
unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
(slistZip env args))
- (arg_offsets, result_offset') = computeStructOffsets arg_metrics
- result_offset = align (alignmentSTy (typeOf expr)) result_offset'
- in ($ "") $ compose
+ (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
+ result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
+
+ offsets = KernelOffsets
+ { koArgOffsets = arg_offsets
+ , koOkResOffset = okres_offset
+ , koResultOffset = result_offset }
+ in (,offsets) . ($ "") $ compose
[showString "#include <stdio.h>\n"
,showString "#include <stdint.h>\n"
+ ,showString "#include <stdbool.h>\n"
,showString "#include <inttypes.h>\n"
,showString "#include <stdlib.h>\n"
,showString "#include <string.h>\n"
,showString "#include <math.h>\n\n"
+ -- PRint-tag
+ ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
+
,compose [printStructDecl sd . showString "\n" | sd <- structs]
,showString "\n"
- ,showString "static void* malloc_instr(size_t n) {\n"
+
+ -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
+ ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
,showString " void *ptr = malloc(n);\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] malloc(%zu) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
,showString " return ptr;\n"
,showString "}\n"
- ,showString "static void* calloc_instr(size_t n) {\n"
+ ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
,showString " void *ptr = calloc(n, 1);\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] calloc(%zu) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
,showString " return ptr;\n"
,showString "}\n"
+ ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
,showString "static void free_instr(void *ptr) {\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] free(%p)\\n\", ptr);\n"
+ ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
else id
,showString " free(ptr);\n"
,showString "}\n\n"
+
,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
+
,showString $
- "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
- intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
+ "static bool typed_kernel(" ++
+ repSTy (typeOf expr) ++ " *output" ++
+ concatMap (", " ++)
+ (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
") {\n"
,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
- ,showString " return " . printCExpr 0 res . showString ";\n}\n\n"
+ ,showString " *output = " . printCExpr 0 res . showString ";\n"
+ ,showString " return true;\n"
+ ,showString "}\n\n"
+
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
- ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n"
+ ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
else id
- ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
- concat (map (\((arg, typ), off, idx) ->
- "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
- ++ (if idx < length arg_pairs - 1 then "," else "")
- ++ " // " ++ arg)
- (zip3 arg_pairs arg_offsets [0::Int ..])) ++
+ ,showString $ " const bool success = typed_kernel(" ++
+ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
+ concat (map (\((arg, typ), off) ->
+ ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+ ++ " /* " ++ arg ++ " */")
+ (zip arg_pairs arg_offsets)) ++
"\n );\n"
- ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n"
+ ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
else id
,showString "}\n"]
@@ -412,11 +459,20 @@ serialise topty topval ptr off k =
serialise a x ptr off $
serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
(STEither a _, Left x) -> do
- pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b})
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
serialise a x ptr (off + alignmentSTy topty) k
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -460,9 +516,16 @@ deserialise topty ptr off =
return (x, y)
STEither a b -> do
tag <- peekByteOff @Word8 ptr off
- if tag == 0 -- alignment of (a + b) is alignment of (union {a b})
+ if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -507,6 +570,10 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
@@ -517,7 +584,7 @@ metricsSTy (STScal sty) = case sty of
STF32 -> (4, 4)
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
-metricsSTy (STAccum t) = metricsSTy t
+metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -647,6 +714,39 @@ compile' env = \case
<> pure (SAsg retvar e3))))
return (CELit retvar)
+ ELNil _ t1 t2 -> do
+ name <- emitStruct (STLEither t1 t2)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ ELInl _ t e -> do
+ name <- emitStruct (STLEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
+
+ ELInr _ t e -> do
+ name <- emitStruct (STLEither t (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
+
+ ELCase _ e a b c -> do
+ let STLEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
+ (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
+ ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2 <> pure (SAsg retvar e2))
+ (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
+ (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
+ (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
+ return (CELit retvar)
+
EConstArr _ n t (Array sh vec) -> do
strname <- emitStruct (STArr n (STScal t))
tldname <- genName' "carraybuf"
@@ -696,8 +796,7 @@ compile' env = \case
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
- resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname))
- [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
@@ -743,8 +842,7 @@ compile' env = \case
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname))
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
@@ -795,8 +893,7 @@ compile' env = \case
resname <- allocArray "repl1i" Malloc "rep" (SS n) t
(Just (CEBinop (CELit shszname) "*" (CELit lenname)))
- ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
- ++ [CELit lenname])
+ (compileArrShapeComponents n argname ++ [CELit lenname])
ivar <- genName' "i"
jvar <- genName' "j"
@@ -840,8 +937,8 @@ compile' env = \case
emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
(CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++
- arrname ++ ".buf); abort();")
+ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
+ arrname ++ ".buf); return false;")
mempty
resname <- genName' "ixres"
@@ -881,6 +978,8 @@ compile' env = \case
maybe (return ()) ($ name2) mfun2
return (CELit name)
+ ERecompute _ e -> compile' env e
+
EWith _ t e1 e2 -> do
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
@@ -888,203 +987,185 @@ compile' env = \case
zeroRefcountCheck (typeOf e1) "with" name1
emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
- mcopy <- copyForWriting (CHAD.d2 t) name1
+ mcopy <- copyForWriting t name1
accname <- genName' "accum"
emit $ SVarDecl False actyname accname
- (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (CHAD.d2 t)))])])
+ (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
e2' <- compile' (Const accname `SCons` env) e2
resname <- genName' "acret"
- emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
- rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
+ rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
- EAccum _ t prj eidx eval eacc -> do
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
-
- -- Generate the variable manually because this one has to be non-const.
- eacc' <- compile' env eacc
- nameacc <- genName' "acac"
- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
-
- let -- Expects a variable reference to a value of type @D2 a@.
- setZero :: STy a -> String -> CompM ()
- setZero STNil _ = return ()
- setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b))
- setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b))
- setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a)
- setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a))
- setZero (STScal sty) v = case sty of
- STI32 -> return () -- Nil
- STI64 -> return () -- Nil
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
- STBool -> return () -- Nil
- setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported"
-
- initD2Pair :: STy a -> STy b -> String -> CompM ()
- initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b))
- ((), stmts1) <- scope $ setZero a (v++".j.a")
- ((), stmts2) <- scope $ setZero b (v++".j.b")
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2)
- mempty
-
- initD2Either :: STy a -> STy b -> String -> Either () () -> CompM ()
- initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b))
- ((), stmts) <- case side of
- Left () -> scope $ setZero a (v++".j.l")
- Right () -> scope $ setZero b (v++".j.r")
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
- mempty
-
- initD2Maybe :: STy a -> String -> CompM ()
- initD2Maybe a v = do -- Maybe (D2 a)
- ((), stmts) <- scope $ setZero a (v++".j")
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
- mempty
-
- -- mind: this has to traverse the D2 of these things, and it also has to
- -- initialise data structures that are still sparse in the accumulator.
- let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String
- accumRef _ SAPHere v _ = pure v
- accumRef (STPair ta tb) (SAPFst prj') v i = do
- initD2Pair ta tb v
- accumRef ta prj' (v++".j.a") i
- accumRef (STPair ta tb) (SAPSnd prj') v i = do
- initD2Pair ta tb v
- accumRef tb prj' (v++".j.b") i
- accumRef (STEither ta tb) (SAPLeft prj') v i = do
- initD2Either ta tb v (Left ())
- accumRef ta prj' (v++".j.l") i
- accumRef (STEither ta tb) (SAPRight prj') v i = do
- initD2Either ta tb v (Right ())
- accumRef tb prj' (v++".j.r") i
- accumRef (STMaybe tj) (SAPJust prj') v i = do
- initD2Maybe tj v
- accumRef tj prj' (v++".j") i
- accumRef (STArr n t') (SAPArrIdx prj' _) v i = do
- (newarrName, newarrStmts) <- scope $ allocArray "accumRef" Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b"))
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1"))
- <> newarrStmts
- <> pure (SAsg (v++".j") (CELit newarrName)))
- mempty
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
+ let -- Add a value (s) into an existing accumulation value (d). If a sparse
+ -- component of d is encountered, s is copied there.
+ add :: SMTy a -> String -> String -> CompM ()
+ add SMTNil _ _ = return ()
+ add (SMTPair t1 t2) d s = do
+ add t1 (d++".a") (s++".a")
+ add t2 (d++".b") (s++".b")
+ add (SMTLEither t1 t2) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
+ ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
+ ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ ((if emitChecks
+ then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
+ "&&"
+ (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
+ "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
+ "return false;")
+ mempty)
+ else mempty)
+ -- note: s may have tag 0
+ <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ stmts1
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
+ stmts2 mempty))))
+ add (SMTMaybe t1) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
+ ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
+ add (SMTArr n t1) d s = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ [0 .. fromSNat n - 1] $ \j -> do
+ emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ "!="
+ (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
+ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
+ d ++ ".buf" ++
+ concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ ", " ++ s ++ ".buf" ++
+ concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ "); " ++
+ "return false;")
+ mempty
+ shsizename <- genName' "acshsz"
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
+ ivar <- genName' "i"
+ ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1
+ add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".j.buf" ++
- concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
- "abort();")
+ "return false;")
mempty
- accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b")
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
- -- mind: this has to add the D2 of these things, and it also has to
- -- initialise data structures that are still sparse in the accumulator.
- let add :: STy a -> String -> String -> CompM ()
- add STNil _ _ = return ()
- add (STPair t1 t2) d s = do
- ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a")
- ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (stmts1 <> stmts2)
- mempty))
- add (STEither t1 t2) d s = do
- ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l")
- ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SAsg (d++".j.tag") (CELit (s++".j.tag")))
- <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0"))
- stmts1 stmts2))
- mempty))
- add (STMaybe t1) d s = do
- ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SAsg (d++".tag") (CELit "1")) <> stmts1)
- mempty))
- add (STArr n t1) d s = do
- shsizename <- genName' "acshsz"
- ivar <- genName' "i"
- ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]")
- ((), stmtsDecr) <- scope $ incrementVarAlways "accumarr" Decrement (STArr n (CHAD.d2 t1)) (s++".j")
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s)))
- (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")))
- -- TODO: emit check here for the source being either equal in shape to the destination
- <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
- stmts1)
- <> stmtsDecr)))
- mempty
- add (STScal sty) d s = case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STBool -> return ()
- add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- dest <- accumRef t prj (nameacc++".buf->ac") nameidx
- add (acPrjTy prj t) dest nameval
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
emit $ SVerbatim $ "// compile EAccum end"
+ incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
+
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
| ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
| otherwise -> [c]
- emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);"
+ emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
case t of
STScal _ -> return (CELit "0")
_ -> do
name <- emitStruct t
return $ CEStruct name []
- EZero{} -> error "Compile: monoid operations should have been eliminated"
- EPlus{} -> error "Compile: monoid operations should have been eliminated"
- EOneHot{} -> error "Compile: monoid operations should have been eliminated"
+ EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EIdx1{} -> error "Compile: not implemented: EIdx1"
compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
compileAssign prefix env e = do
e' <- compile' env e
- name <- genName' prefix
- emit $ SVarDecl True (repSTy (typeOf e)) name e'
- return name
+ case e' of
+ CELit name -> return name
+ _ -> do
+ name <- genName' prefix
+ emit $ SVarDecl True (repSTy (typeOf e)) name e'
+ return name
data Increment = Increment | Decrement
deriving (Show)
@@ -1103,6 +1184,7 @@ data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array
| ATNoop -- ^ don't do anything here
| ATProj String ArrayTree -- ^ descend one field deeper
| ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
+ | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
| ATBoth ArrayTree ArrayTree -- ^ do both these paths
smartATProj :: String -> ArrayTree -> ArrayTree
@@ -1113,6 +1195,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
smartATCondTag ATNoop ATNoop = ATNoop
smartATCondTag t t' = ATCondTag t t'
+smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
+smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
+smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
+
smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
smartATBoth ATNoop t = t
smartATBoth t ATNoop = t
@@ -1124,6 +1210,9 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
@@ -1135,19 +1224,19 @@ incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
Increment -> do
emit $ SVerbatim (path ++ ".buf->refc++;")
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
Decrement -> do
case incrementVar (marker++".elt") Decrement eltty of
Nothing ->
if debugRefc
then do
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
else do
emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
Just f -> do
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
shszvar <- genName' "frshsz"
ivar <- genName' "i"
((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
@@ -1163,6 +1252,15 @@ incrementVar' marker inc path (ATCondTag t1 t2) = do
((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
+incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
+ stmts2
+ (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
+ stmts3
+ stmts1))
incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
toLinearIdx :: SNat n -> String -> String -> CExpr
@@ -1204,7 +1302,7 @@ allocArray marker method nameBase rank eltty mshsz shape = do
emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
return arrname
compileShapeQuery :: SNat n -> String -> CExpr
@@ -1216,10 +1314,12 @@ compileShapeQuery (SS n) var =
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize SZ _ = CELit "1"
-compileArrShapeSize n var =
- foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]")
- | i <- [0 .. fromSNat n - 1]]
+compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeComponents :: SNat n -> String -> [CExpr]
+compileArrShapeComponents n var =
+ [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
indexTupleComponents :: SNat n -> String -> [CExpr]
indexTupleComponents = \n var -> map CELit (toList (go n var))
@@ -1268,6 +1368,7 @@ compileOpGeneral op e1 = do
OLog STF32 -> unary "logf"
OLog STF64 -> unary "log"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
@@ -1281,6 +1382,7 @@ compileOpPair op e1 e2 = do
OAnd -> binary "&&"
OOr -> binary "||"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
_ -> error "compileOpPair: got unary operator"
-- | Bool: whether to ensure that the literal itself already has the appropriate type
@@ -1304,14 +1406,13 @@ compileExtremum nameBase opName operator env e = do
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname))
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
(CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
- emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }"
+ emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
ivar <- genName' "i"
jvar <- genName' "j"
@@ -1332,47 +1433,47 @@ compileExtremum nameBase opName operator env e = do
-- | If this returns Nothing, there was nothing to copy because making a simple
-- value copy in C already makes it suitable to write to.
-copyForWriting :: STy t -> String -> CompM (Maybe CExpr)
+copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
copyForWriting topty var = case topty of
- STNil -> return Nothing
+ SMTNil -> return Nothing
- STPair a b -> do
+ SMTPair a b -> do
e1 <- copyForWriting a (var ++ ".a")
e2 <- copyForWriting b (var ++ ".b")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
- _ -> return $ Just $ CEStruct (repSTy topty)
+ _ -> return $ Just $ CEStruct toptyname
[("a", fromMaybe (CELit (var++".a")) e1)
,("b", fromMaybe (CELit (var++".b")) e2)]
- STEither a b -> do
+ SMTLEither a b -> do
(e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
(e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
_ -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SVarDeclUninit toptyname name
emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(stmts1
- <> pure (SAsg name (CEStruct (repSTy topty)
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
(stmts2
- <> pure (SAsg name (CEStruct (repSTy topty)
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
return (Just (CELit name))
- STMaybe t -> do
+ SMTMaybe t -> do
(e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
case e1 of
Nothing -> return Nothing
Just e1' -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SVarDeclUninit toptyname name
emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
+ (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
(stmts1
- <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
+ <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
return (Just (CELit name))
-- If there are no nested arrays, we know that a refcount of 1 means that the
@@ -1380,26 +1481,26 @@ copyForWriting topty var = case topty of
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
-- let's not do that. Furthermore, no sub-arrays means that the whole thing
-- is flat, and we can just memcpy if necessary.
- STArr n t | not (hasArrays t) -> do
+ SMTArr n t | not (hasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
- emit $ SVarDeclUninit (repSTy (STArr n t)) name
+ emit $ SVarDeclUninit toptyname name
when debugShapes $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
emit $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] with array " ++ shfmt ++ "\\n\"" ++
+ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
");"
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
(pure (SAsg name (CELit var)))
(let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
- ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])])
+ ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
@@ -1407,26 +1508,26 @@ copyForWriting topty var = case topty of
printCExpr 0 databytes ");"])
return (Just (CELit name))
- STArr n t -> do
+ SMTArr n t -> do
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
name <- genName
- emit $ SVarDecl False (repSTy (STArr n t)) name
- (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])])
+ emit $ SVarDecl False toptyname name
+ (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain
dstvar <- genName' "cpydst"
- emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
srcvar <- genName' "cpysrc"
- emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
ivar <- genName' "i"
@@ -1441,9 +1542,10 @@ copyForWriting topty var = case topty of
return (Just (CELit name))
- STScal _ -> return Nothing
+ SMTScal _ -> return Nothing
- STAccum _ -> error "Compile: Nested accumulators not supported"
+ where
+ toptyname = repSTy (fromSMTy topty)
zeroRefcountCheck :: STy t -> String -> String -> CompM ()
zeroRefcountCheck toptyp opname topvar =
@@ -1462,6 +1564,14 @@ zeroRefcountCheck toptyp opname topvar =
go (STEither a b) path = do
(s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
go (STMaybe a) path = do
ss <- go a (path++".j")
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
@@ -1471,8 +1581,8 @@ zeroRefcountCheck toptyp opname topvar =
shszname <- genName' "shsz"
let s1 = SVerbatim $
"if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }"
+ "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
return (BList [s1, s2, s3])
@@ -1489,6 +1599,10 @@ zeroRefcountCheck toptyp opname topvar =
(Nothing, Just y') -> Just (mempty, y')
(Just x', Just y') -> Just (x', y')
+{-# NOINLINE uniqueIdGenRef #-}
+uniqueIdGenRef :: IORef Int
+uniqueIdGenRef = unsafePerformIO $ newIORef 1
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index 9b29486..9b9fb15 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -4,12 +4,13 @@ module Compile.Exec (
KernelLib,
buildKernel,
callKernelFun,
+
+ -- * misc
+ lineNumbers,
) where
import Control.Monad (when)
import Data.IORef
-import qualified Data.Map.Strict as Map
-import Data.Map.Strict (Map)
import Foreign (Ptr)
import Foreign.Ptr (FunPtr)
import System.Directory (removeDirectoryRecursive)
@@ -27,10 +28,10 @@ debug :: Bool
debug = False
-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
-data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ()))))
+data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ())))
-buildKernel :: String -> [String] -> IO KernelLib
-buildKernel csource funnames = do
+buildKernel :: String -> String -> IO KernelLib
+buildKernel csource funname = do
template <- (++ "/tmp.chad.") <$> getTempDir
path <- mkdtemp template
@@ -40,7 +41,8 @@ buildKernel csource funnames = do
,"-std=c99", "-x", "c"
,"-o", outso, "-"
,"-Wall", "-Wextra"
- ,"-Wno-unused-variable", "-Wno-unused-parameter", "-Wno-unused-function"]
+ ,"-Wno-unused-variable", "-Wno-unused-but-set-variable"
+ ,"-Wno-unused-parameter", "-Wno-unused-function"]
(ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource
-- Print the source before the GCC output.
@@ -65,8 +67,7 @@ buildKernel csource funnames = do
removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now
- ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames]
- ref <- newIORef ptrs
+ ref <- newIORef =<< dlsym dl funname
_ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))
when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
dlclose dl)
@@ -77,10 +78,10 @@ foreign import ccall "dynamic"
-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive
{-# NOINLINE callKernelFun #-}
-callKernelFun :: String -> KernelLib -> Ptr () -> IO ()
-callKernelFun key (KernelLib ref) arg = do
- mp <- readIORef ref
- wrapKernelFun (mp Map.! key) arg
+callKernelFun :: KernelLib -> Ptr () -> IO ()
+callKernelFun (KernelLib ref) arg = do
+ ptr <- readIORef ref
+ wrapKernelFun ptr arg
getTempDir :: IO FilePath
getTempDir =
diff --git a/src/Data.hs b/src/Data.hs
index e7b3148..e6978c8 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -8,10 +8,13 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module Data (module Data, (:~:)(Refl)) where
+module Data (module Data, (:~:)(Refl), If) where
import Data.Functor.Product
+import Data.GADT.Compare
+import Data.GADT.Show
import Data.Some
+import Data.Type.Bool (If)
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -73,10 +76,15 @@ data SNat n where
SS :: SNat n -> SNat (S n)
deriving instance Show (SNat n)
-instance TestEquality SNat where
- testEquality SZ SZ = Just Refl
- testEquality (SS n) (SS n') | Just Refl <- testEquality n n' = Just Refl
- testEquality _ _ = Nothing
+instance GCompare SNat where
+ gcompare SZ SZ = GEQ
+ gcompare SZ _ = GLT
+ gcompare _ SZ = GGT
+ gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n')
+
+instance TestEquality SNat where testEquality = geq
+instance GEq SNat where geq = defaultGeq
+instance GShow SNat where gshowsPrec = defaultGshowsPrec
fromSNat :: SNat n -> Int
fromSNat SZ = 0
@@ -90,10 +98,6 @@ reSNat :: Nat -> Some SNat
reSNat Z = Some SZ
reSNat (S n) | Some n' <- reSNat n = Some (SS n')
-fromNat :: Nat -> Int
-fromNat Z = 0
-fromNat (S m) = succ (fromNat m)
-
class KnownNat n where knownNat :: SNat n
instance KnownNat Z where knownNat = SZ
instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
@@ -155,6 +159,18 @@ vecInit (x :< xs@(_ :< _)) = x :< vecInit xs
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = unsafeCoerce Refl
+gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a')
+gorderingLift1 GLT = GLT
+gorderingLift1 GGT = GGT
+gorderingLift1 GEQ = GEQ
+
+gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b')
+gorderingLift2 GLT _ = GLT
+gorderingLift2 GGT _ = GGT
+gorderingLift2 GEQ GLT = GLT
+gorderingLift2 GEQ GGT = GGT
+gorderingLift2 GEQ GEQ = GEQ
+
data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t]
deriving (Show, Functor, Foldable, Traversable)
@@ -169,3 +185,8 @@ instance Applicative Bag where
instance Semigroup (Bag t) where (<>) = BTwo
instance Monoid (Bag t) where mempty = BNone
+
+data SBool b where
+ SF :: SBool False
+ ST :: SBool True
+deriving instance Show (SBool b)
diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs
new file mode 100644
index 0000000..2712b08
--- /dev/null
+++ b/src/Data/VarMap.hs
@@ -0,0 +1,119 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.VarMap (
+ VarMap,
+ empty,
+ insert,
+ delete,
+ TypedIdx(..),
+ lookup,
+ disjointUnion,
+ sink1,
+ unsink1,
+ subMap,
+ superMap,
+) where
+
+import Prelude hiding (lookup)
+
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (mapMaybe)
+import Data.Some
+import qualified Data.Vector.Storable as VS
+import Unsafe.Coerce
+
+import AST.Env
+import AST.Types
+import AST.Weaken
+
+
+type role VarMap _ nominal -- ensure that 'env' is not phantom
+data VarMap k (env :: [Ty]) =
+ VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index
+ Int -- ^ Time since last cleanup
+ (Map k (Some STy, Int))
+deriving instance Show k => Show (VarMap k env)
+
+empty :: VarMap k env
+empty = VarMap 0 0 Map.empty
+
+insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env
+insert k ty idx (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp)
+
+delete :: Ord k => k -> VarMap k env -> VarMap k env
+delete k (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp)
+
+data TypedIdx env t = TypedIdx (STy t) (Idx env t)
+ deriving (Show)
+
+lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env))
+lookup k (VarMap off _ mp) = do
+ (Some ty, i) <- Map.lookup k mp
+ idx <- unsafeInt2idx (i + off)
+ return (Some (TypedIdx ty idx))
+
+disjointUnion :: Ord k => VarMap k env -> VarMap k env -> VarMap k env
+disjointUnion (VarMap off1 cl1 m1) (VarMap off2 cl2 m2) | off1 == off2 =
+ VarMap off1 (min cl1 cl2) (Map.unionWith (error "VarMap.disjointUnion: overlapping keys") m1 m2)
+disjointUnion vm1 vm2 = disjointUnion (cleanup vm1) (cleanup vm2)
+
+sink1 :: VarMap k env -> VarMap k (t : env)
+sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp
+
+unsink1 :: VarMap k (t : env) -> VarMap k env
+unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp
+
+subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'
+subMap subenv =
+ let bools = let loop :: Subenv env env' -> [Bool]
+ loop SETop = []
+ loop (SEYesR sub) = True : loop sub
+ loop (SENo sub) = False : loop sub
+ in VS.fromList $ loop subenv
+ newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools
+ modify off (k, (ty, i))
+ | i + off < 0 = Nothing
+ | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map"
+ | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off)))
+ | otherwise = Nothing
+ in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp)
+
+superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env
+superMap subenv =
+ let loop :: Subenv env env' -> Int -> [Int]
+ loop SETop _ = []
+ loop (SEYesR sub) i = i : loop sub (i+1)
+ loop (SENo sub) i = loop sub (i+1)
+
+ newIndices = VS.fromList $ loop subenv 0
+ modify off (k, (ty, i))
+ | i + off < 0 = Nothing
+ | i + off >= VS.length newIndices = error "VarMap.superMap: found negative indices in map"
+ | otherwise = let j = newIndices VS.! (i + off)
+ in if j == -1 then Nothing else Just (k, (ty, j))
+
+ in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp)
+
+maybeCleanup :: VarMap k env -> VarMap k env
+maybeCleanup vm@(VarMap _ interval mp)
+ | let sz = Map.size mp
+ , sz > 0, 2 * interval >= 3 * sz
+ = cleanup vm
+maybeCleanup vm = vm
+
+cleanup :: VarMap k env -> VarMap k env
+cleanup (VarMap off _ mp) = VarMap 0 0 (Map.mapMaybe (\(t, i) -> if i + off >= 0 then Just (t, i + off) else Nothing) mp)
+
+unsafeInt2idx :: Int -> Maybe (Idx env t)
+unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n)
+ where
+ go :: Int -> Idx env t
+ go 0 = unsafeCoerce IZ
+ go n = unsafeCoerce (IS (go (n-1)))
diff --git a/src/Example.hs b/src/Example.hs
index 2c710a1..b320ead 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -5,11 +5,14 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
+
+{-# OPTIONS -Wno-unused-imports #-}
module Example where
import Array
import AST
import AST.Pretty
+import AST.UnMonoid
import CHAD
import CHAD.Top
import ForwardAD
@@ -30,11 +33,6 @@ bin op a b = EOp ext op (EPair ext a b)
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
-descr1 :: Storage a -> Storage b
- -> Descr [TScal TF32, TScal TF32] [b, a]
-descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
- where t = STScal STF32
-
-- x y |- x * y + x
--
-- let x3 = (x1, x2)
@@ -82,25 +80,12 @@ ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed $ lambda #x $ lambda #y $ body $
if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
-senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
-senv5 = knownEnv
-
-descr5 :: Storage a -> Storage b
- -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
-descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b)
-
-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed $ lambda #x $ lambda #y $ body $
case_ #x (#a :-> #a * #y)
(#b :-> #b * (#y + 1))
-senv6 :: SList STy [TScal TI64, TScal TF32]
-senv6 = knownEnv
-
-descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
-descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-
-- x:R n:I |- let a = unit x
-- b = build1 n (\i. let c = idx0 a in c * c)
-- in idx0 (b ! 3)
@@ -110,12 +95,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
#b ! pair nil 3
-senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
-senv7 = knownEnv
-
-descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
-descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-
-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
@@ -182,9 +161,8 @@ neuralGo =
simplifyN 20 $
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
- (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
- (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
- _ -> undefined
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index 12bbd98..206e534 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -31,10 +31,10 @@ import Language
-- <https://tomsmeding.com/f/master.pdf>
--
-- The 'wrong' argument, when set to True, changes the objective function to
--- one with a bug that makes a certain `build` result unused. This triggers
+-- one with a bug that makes a certain `build` result unused. This
-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's
-- dense, even though it may be a zero (i.e. empty). The "unused" test in
--- test/Main.hs tries to isolate this test, but the wrong version of
+-- test/Main.hs tries to isolate this case, but the wrong version of
-- gmmObjective is here to check (after that bug is fixed) whether it really
-- fixes the original bug.
gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index af35f91..b353def 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -26,6 +26,7 @@ type family Tan t where
Tan TNil = TNil
Tan (TPair a b) = TPair (Tan a) (Tan b)
Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
@@ -45,6 +46,7 @@ tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair a b) = STPair (tanty a) (tanty b)
tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanty (STMaybe t) = STMaybe (tanty t)
tanty (STArr n t) = STArr n (tanty t)
tanty (STScal t) = case t of
@@ -55,11 +57,18 @@ tanty (STScal t) = case t of
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
+
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
zeroTan (STMaybe _) Nothing = Nothing
zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
zeroTan (STArr _ t) x = fmap (zeroTan t) x
@@ -75,6 +84,9 @@ tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
@@ -98,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =
unzipDN (STEither a b) d = case d of
Left d1 -> bimap Left Left (unzipDN a d1)
Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
unzipDN (STMaybe t) d = case d of
Nothing -> (Nothing, Nothing)
Just d' -> bimap Just Just (unzipDN t d')
@@ -120,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of
(Left x', Left y') -> dotprodTan a x' y'
(Right x', Right y') -> dotprodTan b x' y'
_ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
(Nothing, Nothing) -> 0.0
(Just x', Just y') -> dotprodTan t x' y'
@@ -165,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
dnConst (STMaybe t) = fmap (dnConst t)
dnConst (STArr _ t) = arrayMap (dnConst t)
dnConst (STScal t) = case t of
@@ -188,6 +211,11 @@ dnOnehots (STEither t1 t2) e =
case e of
Left x -> \f -> Left (dnOnehots t1 x (f . Left))
Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnOnehots (STMaybe t) m =
case m of
Nothing -> \_ -> Nothing
@@ -223,7 +251,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr =
let dexpr = dfwdDN expr
- in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
{-# NOINLINE makeFwdADArtifactCompile #-}
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 9a95f81..3ab08af 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -86,6 +86,9 @@ dop = \case
OIDiv t -> scalTyCase t
(case t of {})
(EOp ext (OIDiv t))
+ OMod t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OMod t))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
@@ -140,6 +143,10 @@ dfwdDN = \case
ENothing _ t -> ENothing ext (dn t)
EJust _ e -> EJust ext (dfwdDN e)
EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2)
+ ELInl _ t e -> ELInl ext (dn t) (dfwdDN e)
+ ELInr _ t e -> ELInr ext (dn t) (dfwdDN e)
+ ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c)
EConstArr _ n t x -> scalTyCase t
(emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
(EConstArr ext n t x))
@@ -178,10 +185,12 @@ dfwdDN = \case
ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
+ ERecompute _ e -> dfwdDN e
EError _ t s -> EError ext (dn t) s
EWith{} -> err_accum
EAccum{} -> err_accum
+ EDeepZero{} -> err_monoid
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
index fba92d0..dcacf5f 100644
--- a/src/ForwardAD/DualNumbers/Types.hs
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -12,6 +12,7 @@ type family DN t where
DN TNil = TNil
DN (TPair a b) = TPair (DN a) (DN b)
DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TLEither a b) = TLEither (DN a) (DN b)
DN (TMaybe t) = TMaybe (DN t)
DN (TArr n t) = TArr n (DN t)
DN (TScal t) = DNS t
@@ -31,6 +32,7 @@ dn :: STy t -> STy (DN t)
dn STNil = STNil
dn (STPair a b) = STPair (dn a) (dn b)
dn (STEither a b) = STEither (dn a) (dn b)
+dn (STLEither a b) = STLEither (dn a) (dn b)
dn (STMaybe t) = STMaybe (dn t)
dn (STArr n t) = STArr n (dn t)
dn (STScal t) = case t of
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ddc3479..ffc2929 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,9 +21,11 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
import System.IO (hPutStrLn, stderr)
@@ -34,7 +36,7 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
-import CHAD.Types
+import AST.Sparse.Types
import Data
import Interpreter.Rep
@@ -48,35 +50,39 @@ runAcM (AcM m) = unsafePerformIO m
acmDebugLog :: String -> AcM s ()
acmDebugLog s = AcM (hPutStrLn stderr s)
+data V t = V (STy t) (Rep t)
+
interpret :: Ex '[] t -> Rep t
-interpret = interpretOpen False SNil
+interpret = interpretOpen False SNil SNil
-- | Bool: whether to trace execution with debug prints (very verbose)
-interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t
-interpretOpen prints env e =
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
runAcM $
let ?depth = 0
?prints = prints
- in interpret' env e
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
-interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
let dep = ?depth
let lenlimit = max 20 (100 - dep)
let replace a b = map (\c -> if c == a then b else c)
let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
| otherwise = replace '\n' ' ' s
- when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
res <- let ?depth = dep + 1 in interpret'Rec env e
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
return res
-interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
interpret'Rec env = \case
- EVar _ _ i -> case slistIdx env i of Value x -> return x
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
ELet _ a b -> do
x <- interpret' env a
- let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
@@ -84,18 +90,32 @@ interpret'Rec env = \case
ENil _ -> return ()
EInl _ _ e -> Left <$> interpret' env e
EInr _ _ e -> Right <$> interpret' env e
- ECase _ e a b -> interpret' env e >>= \case
- Left x -> interpret' (Value x `SCons` env) a
- Right y -> interpret' (Value y `SCons` env) b
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
ENothing _ _ -> return Nothing
EJust _ e -> Just <$> interpret' env e
- EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
+ ELNil _ _ _ -> return Nothing
+ ELInl _ _ e -> Just . Left <$> interpret' env e
+ ELInr _ _ e -> Just . Right <$> interpret' env e
+ ELCase _ e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Nothing -> interpret' env a
+ Just (Left x) -> interpret' (V t1 x `SCons` env) b
+ Just (Right y) -> interpret' (V t2 y `SCons` env) c
EConstArr _ _ _ v -> return v
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
- arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
EFold1Inner _ _ a b c -> do
- let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ let t = typeOf b
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
@@ -126,34 +146,39 @@ interpret'Rec env = \case
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
- EIdx _ a b
- | STArr n _ <- typeOf a
- -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EIdx _ a b ->
+ let STArr n _ = typeOf a
+ in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
- ECustom _ _ _ _ pr _ _ e1 e2 -> do
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
e1' <- interpret' env e1
e2' <- interpret' env e2
- interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
- interpret' (Value accum `SCons` env) e2
- EAccum _ t p e1 e2 e3 -> do
+ interpret' (V (STAccum t) accum `SCons` env) e2
+ EAccum _ t p e1 sp e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
- EZero _ t -> do
- return $ zeroD2 t
+ accumAddSparseD t p accum idx sp val
+ EZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ addD2s t a' b'
+ return $ addM t a' b'
EOneHot _ t p a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ onehotD2 p t a' b'
+ return $ onehotM p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -174,6 +199,7 @@ interpretOp op arg = case op of
OExp st -> floatingIsFractional st $ exp arg
OLog st -> floatingIsFractional st $ log arg
OIDiv st -> integralIsIntegral st $ uncurry quot arg
+ OMod st -> integralIsIntegral st $ uncurry rem arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
@@ -182,211 +208,161 @@ interpretOp op arg = case op of
styIsEq STF64 = id
styIsEq STBool = id
-zeroD2 :: STy t -> Rep (D2 t)
-zeroD2 typ = case typ of
- STNil -> ()
- STPair _ _ -> Nothing
- STEither _ _ -> Nothing
- STMaybe _ -> Nothing
- STArr _ _ -> Nothing
- STScal sty -> case sty of
- STI32 -> ()
- STI64 -> ()
+zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t
+zeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi))
+ SMTLEither _ _ -> Nothing
+ SMTMaybe _ -> Nothing
+ SMTArr _ t -> arrayMap (zeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
- STBool -> ()
- STAccum{} -> error "Zero of Accum"
-addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
-addD2s typ a b = case typ of
- STNil -> ()
- STPair t1 t2 -> case (a, b) of
- (Nothing, _) -> b
- (_, Nothing) -> a
- (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)
- STEither t1 t2 -> case (a, b) of
- (Nothing, _) -> b
- (_, Nothing) -> a
- (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y))
- (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))
- _ -> error "Plus of inconsistent Eithers"
- STMaybe t -> case (a, b) of
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+addM :: SMTy t -> Rep t -> Rep t -> Rep t
+addM typ a b = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b))
+ SMTLEither t1 t2 -> case (a, b) of
(Nothing, _) -> b
(_, Nothing) -> a
- (Just x, Just y) -> Just (addD2s t x y)
- STArr _ t -> case (a, b) of
+ (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y))
+ _ -> error "Plus of inconsistent LEithers"
+ SMTMaybe t -> case (a, b) of
(Nothing, _) -> b
(_, Nothing) -> a
- (Just x, Just y) ->
- let sh1 = arrayShape x
- sh2 = arrayShape y
- in if | shapeSize sh1 == 0 -> Just y
- | shapeSize sh2 == 0 -> Just x
- | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i))
- | otherwise -> error "Plus of inconsistently shaped arrays"
- STScal sty -> case sty of
- STI32 -> ()
- STI64 -> ()
- STF32 -> a + b
- STF64 -> a + b
- STBool -> ()
- STAccum{} -> error "Plus of Accum"
-
-onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a)
-onehotD2 SAPHere _ _ val = val
-onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b)
-onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val)
-onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val))
-onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
-onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
-onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
- Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
-
-withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
+ (Just x, Just y) -> Just (addM t x y)
+ SMTArr _ t ->
+ let sh1 = arrayShape a
+ sh2 = arrayShape b
+ in if | shapeSize sh1 == 0 -> b
+ | shapeSize sh2 == 0 -> a
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
+ SMTScal sty -> numericIsNum sty $ a + b
+
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
+onehotM SAPHere _ _ val = val
+onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
+onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
+onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val))
+onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val))
+onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val)
+onehotM (SAPArrIdx prj) (SMTArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx
+
+withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t SAPHere () initval
+ accum <- newAcDense t initval
out <- unAcM $ f accum
- val <- readAcSparse t accum
+ val <- readAc t accum
return (out, val)
-newAcZero :: STy t -> IO (RepAc t)
-newAcZero = \case
- STNil -> return ()
- STPair{} -> newIORef Nothing
- STEither{} -> newIORef Nothing
- STMaybe _ -> newIORef Nothing
- STArr _ _ -> newIORef Nothing
- STScal sty -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> newIORef 0.0
- STF64 -> newIORef 0.0
- STBool -> return ()
- STAccum{} -> error "Nested accumulators"
-
-newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
-newAcSparse typ prj idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
- (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
- (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
- (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
- (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val
- (STScal sty, SAPHere) -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> newIORef val
- STF64 -> newIORef val
- STBool -> return ()
-
- (STPair t1 t2, SAPFst prj') ->
- newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2
- (STPair t1 t2, SAPSnd prj') ->
- newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val
-
- (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
- (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
-
- (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
-
- (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
-
-newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a))
-newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx
+newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
+newAcDense typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val
+ SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val
+ SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
+ SMTArr _ t1 -> arrayMapM (newAcDense t1) val
+ SMTScal _ -> newIORef val
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
- -> m v -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
-onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
+ -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ arrsh = arrayShape ziarr
!linindex = toLinearIndex arrsh arrindex
- in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero)
-
-readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
-readAcSparse typ val = case typ of
- STNil -> return ()
- STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- STMaybe t -> traverse (readAcSparse t) =<< readIORef val
- STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val
- STScal sty -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> readIORef val
- STF64 -> readIORef val
- STBool -> return ()
- STAccum{} -> error "Nested accumulators"
-
-accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
-
- (STPair t1 t2, SAPHere) ->
- case val of
- Nothing -> return ()
- Just (val1, val2) ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
- <*> newAcSparse t2 SAPHere () val2)
- (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
- accumAddSparse t2 SAPHere ac2 () val2)
- (STPair t1 t2, SAPFst prj') ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
- (STPair t1 t2, SAPSnd prj') ->
- realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
-
- (STEither{}, SAPHere) ->
+ in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
+
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
+ SMTScal _ -> readIORef val
+
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
case val of
Nothing -> return ()
- Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
- Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- (STEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
- Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (STEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
- Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
-
- (STMaybe{}, SAPHere) ->
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
case val of
Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- (STMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
-
- (STArr _ t1, SAPHere) ->
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
case val of
Nothing -> return ()
Just val' ->
- realiseMaybeSparse ref
- (arrayMapM (newAcSparse t1 SAPHere ()) val')
- (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
- accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj' _) ->
- let ((arrindex', arrsh'), idx') = idx
- arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
- linindex = toLinearIndex arrsh arrindex
- in realiseMaybeSparse ref
- (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
- (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
-
- (STScal sty, SAPHere) -> AcM $ case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> return ()
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
-
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index be2a4cc..1682303 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -1,12 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
+import Control.DeepSeq
+import Data.Coerce (coerce)
import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
-import GHC.TypeError
+import GHC.Exts (withDict)
import Array
import AST
@@ -18,27 +22,20 @@ type family Rep t where
Rep TNil = ()
Rep (TPair a b) = (Rep a, Rep b)
Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
Rep (TAccum t) = RepAc t
--- Mutable, represents D2 of t. Has an O(1) zero.
+-- Mutable, represents monoid types t.
type family RepAc t where
RepAc TNil = ()
- RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b))
- RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
+ RepAc (TPair a b) = (RepAc a, RepAc b)
+ RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
RepAc (TMaybe t) = IORef (Maybe (RepAc t))
- RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t)))
- RepAc (TScal sty) = RepAcScal sty
- RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
-
-type family RepAcScal t where
- RepAcScal TI32 = ()
- RepAcScal TI64 = ()
- RepAcScal TF32 = IORef Float
- RepAcScal TF64 = IORef Double
- RepAcScal TBool = ()
+ RepAc (TArr n t) = Array n (RepAc t)
+ RepAc (TScal sty) = IORef (ScalRep sty)
newtype Value t = Value { unValue :: Rep t }
@@ -57,8 +54,11 @@ vUnpair (Value (x, y)) = (Value x, Value y)
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
-showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
-showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
+showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x
+showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y
+showValue _ (STLEither _ _) Nothing = showString "LNil"
+showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
+showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr = showParen (d > 10) $
@@ -66,13 +66,13 @@ showValue d (STArr _ t) arr = showParen (d > 10) $
. showString " ["
. foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
. showString "]"
-showValue _ (STScal sty) x = case sty of
- STF32 -> shows x
- STF64 -> shows x
- STI32 -> shows x
- STI64 -> shows x
- STBool -> shows x
-showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">"
+showValue d (STScal sty) x = case sty of
+ STF32 -> showsPrec d x
+ STF64 -> showsPrec d x
+ STI32 -> showsPrec d x
+ STI64 -> showsPrec d x
+ STBool -> showsPrec d x
+showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
@@ -80,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries :: SList STy env -> SList Value env -> [String]
showEntries SNil SNil = []
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+
+rnfRep :: STy t -> Rep t -> ()
+rnfRep STNil () = ()
+rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
+rnfRep (STEither a _) (Left x) = rnfRep a x
+rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
+rnfRep (STMaybe _) Nothing = ()
+rnfRep (STMaybe t) (Just x) = rnfRep t x
+rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+rnfRep (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+rnfRep STAccum{} _ = error "Cannot rnf accumulators"
+
+instance KnownTy t => NFData (Value t) where
+ rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/Language.hs b/src/Language.hs
index a66b8b6..4e6d604 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -17,6 +17,7 @@ module Language (
import Array
import AST
+import AST.Sparse.Types
import AST.Types
import CHAD.Types
import Data
@@ -149,6 +150,9 @@ infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape = NEShape
+length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
+length_ e = snd_ (shape e)
+
oper :: SOp a t -> NExpr env a -> NExpr env t
oper = NEOp
@@ -166,11 +170,17 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
-with :: forall t a env acname. KnownTy t => NExpr env (D2 t) -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a (D2 t))
-with a (n :-> b) = NEWith (knownTy @t) a n b
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
+with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
+with a (n :-> b) = NEWith (knownMTy @t) a n b
-accum :: KnownTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil
-accum p a b c = NEAccum knownTy p a b c
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
@@ -204,6 +214,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB
or_ = oper2 OOr
infixr 2 `or_`
+mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
+mod_ = oper2 (OMod knownScalTy)
+infixl 7 `mod_`
+
-- | The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 84544f8..be98ccf 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -71,9 +72,12 @@ data NExpr env t where
-> NExpr env a -> NExpr env b
-> NExpr env t
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
-- accumulation effect on monoids
- NEWith :: STy t -> NExpr env (D2 t) -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a (D2 t))
- NEAccum :: STy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil
+ NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -215,9 +219,10 @@ fromNamedExpr val = \case
(fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
(fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
(go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
NEWith t a n b -> EWith ext t (go a) (lambda val n b)
- NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
NEError t s -> EError ext t s
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 0bf5482..74b6601 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,8 +1,12 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -10,24 +14,31 @@
{-# LANGUAGE TypeOperators #-}
module Simplify (
simplifyN, simplifyFix,
- SimplifyConfig(..), simplifyWith, simplifyFixWith,
+ SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
+import Control.Monad (ap)
+import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
-import Data.Type.Equality (testEquality)
+
+import Debug.Trace
import AST
import AST.Count
-import CHAD.Types
+import AST.Pretty
+import AST.Sparse.Types
+import AST.UnMonoid (acPrjCompose)
import Data
+import Simplify.TH
--- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
data SimplifyConfig = SimplifyConfig
+ { scLogging :: Bool
+ }
defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig False
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
@@ -37,13 +48,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = defaultSimplifyConfig
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig
@@ -53,22 +64,74 @@ simplifyFixWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
in fix $ \loop e ->
- let (Any act, e') = simplify' e
+ let (act, e') = runSM (simplify' e)
in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
-simplify' = \case
+-- | simplify monad
+newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
+ deriving (Functor)
+
+instance Applicative (SM tenv tt env t) where
+ pure x = SM (\_ -> (Any False, x))
+ (<*>) = ap
+
+instance Monad (SM tenv tt env t) where
+ SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
+
+runSM :: SM env t env t a -> (Bool, a)
+runSM (SM f) = first getAny (f id)
+
+smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
+smReconstruct core = SM (\ctx -> (Any False, ctx core))
+
+class Monad m => ActedMonad m where
+ tellActed :: m ()
+ hideActed :: m a -> m a
+ liftActed :: (Any, a) -> m a
+
+instance ActedMonad ((,) Any) where
+ tellActed = (Any True, ())
+ hideActed (_, x) = (Any False, x)
+ liftActed = id
+
+instance ActedMonad (SM tenv tt env t) where
+ tellActed = SM (\_ -> tellActed)
+ hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
+ liftActed pair = SM (\_ -> pair)
+
+-- more convenient in practice
+acted :: ActedMonad m => m a -> m a
+acted m = tellActed >> m
+
+within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
+within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify' expr
+ | scLogging ?config = do
+ res <- simplify'Rec expr
+ full <- smReconstruct res
+ let printed = ppExpr knownEnv full
+ replace a bs = concatMap (\x -> if x == a then bs else [x])
+ str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
+ | otherwise = "--- simplify step: " ++ printed
+ traceM str
+ return res
+ | otherwise = simplify'Rec expr
+
+simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify'Rec = \case
-- inlining
ELet _ rhs body
| cheapExpr rhs
- -> acted $ simplify' (subst1 rhs body)
+ -> acted $ simplify' (substInline rhs body)
| Occ lexOcc runOcc <- occCount IZ body
, ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
|| (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
- -> acted $ simplify' (subst1 rhs body)
+ -> acted $ simplify' (substInline rhs body)
- -- let splitting
+ -- let splitting / let peeling
ELet _ (EPair _ a b) body ->
acted $ simplify' $
ELet ext a $
@@ -76,13 +139,20 @@ simplify' = \case
subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
IS i -> EVar ext t (IS (IS i)))
body
+ ELet _ (EJust _ a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInl _ t2 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInr _ t1 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
-- let rotation
- ELet _ (ELet _ rhs a) b ->
+ ELet _ (ELet _ rhs a) b -> do
+ b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
acted $ simplify' $
ELet ext rhs $
ELet ext a $
- weakenExpr (WCopy WSink) (snd (simplify' b))
+ weakenExpr (WCopy WSink) b'
-- beta rules for products
EFst _ (EPair _ e e')
@@ -100,12 +170,20 @@ simplify' = \case
EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
- -- let floating to facilitate beta reduction
+ -- let floating
EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
+ EAccum _ t p e1 sp (ELet _ rhs body) acc ->
+ acted $ simplify' $
+ ELet ext rhs $
+ EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
+
+ -- let () = e in () ~> e
+ ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
+ acted $ simplify' e1
-- projection down-commuting
EFst _ (ECase _ e1 e2 e3) ->
@@ -114,89 +192,150 @@ simplify' = \case
ESnd _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
+ EFst _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (EFst ext e1) (EFst ext e2) e3
+ ESnd _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
+
+ -- TODO: more array indexing
+ EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
- -- TODO: array indexing (index of build, index of fold)
+ -- TODO: more array shape
+ EShape _ (EBuild _ _ e _) -> acted $ simplify' e
- -- TODO: beta rules for maybe
+ -- TODO: more constant folding
+ EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
- -- TODO: constant folding for operations
+ -- inline cheap array constructors
+ ELet _ (EReplicate1Inner _ e1 e2) e3 ->
+ acted $ simplify' $
+ ELet ext (EPair ext e1 e2) $
+ let v = EVar ext (STPair tIx (typeOf e2)) IZ
+ in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
+ -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
+ -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
+ -- -- Should do proper SoA representation.
+ -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
+ -- acted $ simplify' $
+ -- ELet ext e1 $
+ -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
+
+ -- eta rule for unit
+ e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
+ case e of
+ ENil _ -> return e
+ _ -> acted $ return (ENil ext)
+
+ EBuild _ SZ _ e ->
+ acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
- EAccum _ t p e1 e2 acc -> do
- acc' <- simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, ENil ext)
- (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc'))
- EPlus _ _ (EZero _ _) e -> acted $ simplify' e
- EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ t p e1 e2 ->
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, EZero ext t)
- (\e -> (Any True, e))
- (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2'))
+ EAccum _ t p e1 sp e2 acc -> do
+ e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
+ simplifyOHT (OneHotTerm SAID t p e1' sp e2')
+ (acted $ return (ENil ext))
+ (\sp' (InContext w wrap e) -> do
+ e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
+ return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
+ (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
+ -- The acted management here is a hideous mess.
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
+ return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
+ EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
+ EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
+ EOneHot _ t p e1 e2 -> do
+ e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
+ e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
+ simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
+ (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
+ (\sp' (InContext _ wrap e) ->
+ case isDense t sp' of
+ Just Refl -> do
+ e' <- hideActed $ within wrap $ simplify' e
+ return (wrap e')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+ (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
+ case isDense (acPrjTy p' t') sp' of
+ Just Refl -> do
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
+ return (wrap $ EOneHot ext t' p' e1''' e2''')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
-- type-specific equations for plus
- EPlus _ STNil _ _ -> (Any True, ENil ext)
+ EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
+ acted $ return (ENil ext)
- EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
- acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
- EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
- EPlus _ STPair{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
+ acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
- EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
- acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
- EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
- acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
- EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
- EPlus _ STEither{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
+ acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
+ EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
+ acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
+ EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
+ EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
- EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
+ EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
- EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e
+ EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
+ EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
- EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
- EFst _ e -> EFst ext <$> simplify' e
- ESnd _ e -> ESnd ext <$> simplify' e
+ ELet _ a b -> [simprec| ELet ext *a *b |]
+ EPair _ a b -> [simprec| EPair ext *a *b |]
+ EFst _ e -> [simprec| EFst ext *e |]
+ ESnd _ e -> [simprec| ESnd ext *e |]
ENil _ -> pure $ ENil ext
- EInl _ t e -> EInl ext t <$> simplify' e
- EInr _ t e -> EInr ext t <$> simplify' e
- ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
+ EInl _ t e -> [simprec| EInl ext t *e |]
+ EInr _ t e -> [simprec| EInr ext t *e |]
+ ECase _ e a b -> [simprec| ECase ext *e *a *b |]
ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> EJust ext <$> simplify' e
- EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ EJust _ e -> [simprec| EJust ext *e |]
+ EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
+ ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
+ ELInl _ t e -> [simprec| ELInl ext t *e |]
+ ELInr _ t e -> [simprec| ELInr ext t *e |]
+ ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
- ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
- EUnit _ e -> EUnit ext <$> simplify' e
- EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
- EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
- EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
+ EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
+ ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
+ EUnit _ e -> [simprec| EUnit ext *e |]
+ EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
+ EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
+ EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> EIdx0 ext <$> simplify' e
- EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
- EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
- EShape _ e -> EShape ext <$> simplify' e
- EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t p a b c e1 e2 ->
- ECustom ext s t p
- <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
- EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t -> pure $ EZero ext t
- EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
+ EIdx0 _ e -> [simprec| EIdx0 ext *e |]
+ EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
+ EIdx _ a b -> [simprec| EIdx ext *a *b |]
+ EShape _ e -> [simprec| EShape ext *e |]
+ EOp _ op e -> [simprec| EOp ext op *e |]
+ ECustom _ s t p a b c e1 e2 -> do
+ a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
+ b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
+ c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
+ e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
+ pure (ECustom ext s t p a' b' c' e1' e2')
+ ERecompute _ e -> [simprec| ERecompute ext *e |]
+ EWith _ t e1 e2 -> do
+ e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
+ pure (EWith ext t e1' e2')
+ EZero _ t e -> [simprec| EZero ext t *e |]
+ EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
-acted :: (Any, a) -> (Any, a)
-acted (_, x) = (Any True, x)
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
@@ -204,6 +343,7 @@ cheapExpr = \case
EConst{} -> True
EFst _ e -> cheapExpr e
ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
_ -> False
-- | This can be made more precise by tracking (and not counting) adds on
@@ -222,6 +362,10 @@ hasAdds = \case
ENothing _ _ -> False
EJust _ e -> hasAdds e
EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
+ ELNil _ _ _ -> False
+ ELInl _ _ e -> hasAdds e
+ ELInr _ _ e -> hasAdds e
+ ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
EConstArr _ _ _ _ -> False
EBuild _ _ a b -> hasAdds a || hasAdds b
EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
@@ -238,8 +382,10 @@ hasAdds = \case
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ _ -> True
- EZero _ _ -> False
+ ERecompute _ e -> hasAdds e
+ EAccum _ _ _ _ _ _ _ -> True
+ EZero _ _ e -> hasAdds e
+ EDeepZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
@@ -252,49 +398,202 @@ checkAccumInScope = \case SNil -> False
check STNil = False
check (STPair s t) = check s || check t
check (STEither s t) = check s || check t
+ check (STLEither s t) = check s || check t
check (STMaybe t) = check t
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
-data OneHotTerm env p a b where
- OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
-
-simplifyOneHotTerm :: OneHotTerm env p a b
- -> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
- -> (Any, r)
-simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do (Any True, ()) -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
-simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e
-simplifyOneHotTerm term _ _ k = k term
-
-concatOneHots :: STy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
-
- (STPair a _, SAPFst prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
- (STPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
-
- (STEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (STEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
-
- (STMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
-
- (STArr n a, SAPArrIdx prj1' _) ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+data OneHotTerm dense env a where
+ OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
+deriving instance Show (OneHotTerm dense env a)
+
+data InContext f env (a :: Ty) where
+ InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
+
+simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
+ val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
+ return $ OneHotTerm dense t prj idx sp val'
+
+simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
+simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
+ unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
+ acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
+ return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
+simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
+
+simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
+ | Just Refl <- isDense (acPrjTy prj1 t1) sp =
+ let idx2' :: Ex env (AcIdx dense p2 c)
+ idx2' = case dense of
+ SAID -> reduceAcIdx t2 prj2 idx2
+ SAIS -> idx2
+ in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
+ acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
+simplifyOHT_concat oht = return oht
+
+-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
+-- -- dense, then the Sparse in the output will also be dense. This property is
+-- -- used when simplifying EOneHot, which cannot represent sparsity.
+simplifyOHT :: ActedMonad m => OneHotTerm dense env a
+ -> m r -- ^ Zero case (onehot is actually zero)
+ -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
+ -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
+ -> m r
+simplifyOHT oht kzero ktriv k = do
+ -- traceM $ "sOHT: input " ++ show oht
+ oht1 <- simplifyOHT_recogniseMonoid oht
+ -- traceM $ "sOHT: recog " ++ show oht1
+ InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
+ -- traceM $ "sOHT: unspa " ++ show oht2
+ oht3 <- simplifyOHT_concat oht2
+ -- traceM $ "sOHT: conca " ++ show oht3
+ -- traceM ""
+ case oht3 of
+ OneHotTerm _ _ _ _ _ EZero{} -> kzero
+ OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
+ _ -> k (InContext w1 wrap1 oht3)
+
+-- Sets the acted flag whenever a non-trivial projection is returned or the
+-- output Sparse is different from the input Sparse.
+unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
+ -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
+ -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
+unsparseOneHotD topsp topval k = case (topsp, topval) of
+ -- eliminate always-Just sparse onehot
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
+
+ -- expand the top levels of a onehot for a sparse type into a onehot for the
+ -- corresponding non-sparse type
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPFst spprj) idx' s1' e'
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPSnd spprj) idx' s1' e'
+ (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPLeft spprj) idx' s1' e'
+ (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPRight spprj) idx' s1' e'
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPJust spprj) idx' s1' e'
+ (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
+ | Dict <- styKnown (typeOf idx) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
+ acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
+
+ -- anything else we don't know how to improve
+ _ -> k WId id SAPHere (ENil ext) topsp topval
+
+{-
+unsparseOneHotS :: ActedMonad m
+ => Sparse a a' -> Ex env a'
+ -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
+unsparseOneHotS topsp topval k = case (topsp, topval) of
+ -- order is relevant to make sure we set the acted flag correctly
+ (SpAbsent, v@ENil{}) -> k SpAbsent v
+ (SpAbsent, v@EZero{}) -> k SpAbsent v
+ (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+
+ -- the unsparsifying
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
+
+ -- recursion
+ -- TODO: coproducts could safely become projections as they do not need
+ -- zeroinfo. But that would only work if the coproduct is at the top, because
+ -- as soon as we hit a product, we need zeroinfo to make it a projection and
+ -- we don't have that.
+ (SpSparse s, e) -> k (SpSparse s) e
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
+ acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
+ acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
+ case s2 of SpAbsent -> pure () ; _ -> tellActed
+ k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
+ case s1 of SpAbsent -> pure () ; _ -> tellActed
+ acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
+ k (SpMaybe s1') (EJust ext e')
+ (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
+ k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
+ _ -> _
+-}
+
+-- | Recognises 'EZero' and 'EOneHot'.
+recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
+recogniseMonoid _ e@EOneHot{} = return e
+recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
+ ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (a', b') -> return $ EPair ext a' b'
+recogniseMonoid typ@(SMTLEither t1 t2) expr =
+ case expr of
+ ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ _ -> return expr
+recogniseMonoid typ@(SMTMaybe t1) expr =
+ case expr of
+ ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ _ -> return expr
+recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
+ acted $ do
+ e' <- recogniseMonoid t e
+ return $
+ ELet ext e' $
+ EOneHot ext typ (SAPArrIdx SAPHere)
+ (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ))))
+ (ENil ext))
+ (EVar ext (fromSMTy t) IZ)
+recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
+ (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ _ -> return e
+recogniseMonoid _ e = return e
+
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
+ (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
+ (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
+ where
+ -- invariant: AcIdx expression is duplicable
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go t SAPHere _ e = makeZeroInfo t e
+ go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
+ go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
+ go SMTLEither{} _ _ _ = ENil ext
+ go SMTMaybe{} _ _ _ = ENil ext
+ go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs
new file mode 100644
index 0000000..2e0076a
--- /dev/null
+++ b/src/Simplify/TH.hs
@@ -0,0 +1,80 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Simplify.TH (simprec) where
+
+import Data.Bifunctor (first)
+import Data.Char
+import Data.List (foldl1')
+import Language.Haskell.TH
+import Language.Haskell.TH.Quote
+import Text.ParserCombinators.ReadP
+
+
+-- [simprec| EPair ext *a *b |]
+-- ~>
+-- do a' <- within (\a' -> EPair ext a' b) (simplify' a)
+-- b' <- within (\b' -> EPair ext a' b') (simplify' b)
+-- pure (EPair ext a' b')
+
+simprec :: QuasiQuoter
+simprec = QuasiQuoter
+ { quoteDec = \_ -> fail "simprec used outside of expression context"
+ , quoteType = \_ -> fail "simprec used outside of expression context"
+ , quoteExp = handler
+ , quotePat = \_ -> fail "simprec used outside of expression context"
+ }
+
+handler :: String -> Q Exp
+handler str =
+ case readP_to_S pTemplate str of
+ [(template, "")] -> generate template
+ _:_:_ -> fail "simprec: template grammar ambiguous"
+ _ -> fail "simprec: could not parse template"
+
+generate :: Template -> Q Exp
+generate (Template topitems) =
+ let takePrefix (Plain x : xs) = first (x:) (takePrefix xs)
+ takePrefix xs = ([], xs)
+
+ itemVar "" = error "simprec: empty item name?"
+ itemVar name@(c:_) | isLower c = VarE (mkName name)
+ | isUpper c = ConE (mkName name)
+ | otherwise = error "simprec: non-letter item name?"
+
+ loop :: Exp -> [Item] -> Q [Stmt]
+ loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)]
+ loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs
+ loop yet (Recurse x : xs) = do
+ primeName <- newName (x ++ "'")
+ let appPrePrime e (Plain y) = e `AppE` itemVar y
+ appPrePrime e (Recurse y) = e `AppE` itemVar y
+ let stmt = BindS (VarP primeName) $
+ VarE (mkName "within")
+ `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs)
+ `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x))
+ stmts <- loop (yet `AppE` VarE primeName) xs
+ return (stmt : stmts)
+
+ (prefix, items') = takePrefix topitems
+ in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items'
+
+data Template = Template [Item]
+ deriving (Show)
+
+data Item = Plain String | Recurse String
+ deriving (Show)
+
+pTemplate :: ReadP Template
+pTemplate = do
+ items <- many (skipSpaces >> pItem)
+ skipSpaces
+ eof
+ return (Template items)
+
+pItem :: ReadP Item
+pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName)
+
+pName :: ReadP String
+pName = do
+ c1 <- satisfy (\c -> isAlpha c || c == '_')
+ cs <- munch (\c -> isAlphaNum c || c `elem` "_'")
+ return (c1:cs)