summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-06 22:50:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-06 22:50:06 +0200
commit56056c98b2e3dce65a0e42bce0410c083fd1f8be (patch)
tree8db2d1be037f8f980c3d1bf76ff9078048f33d63
parent7bd37711ffecb7a0e202ecfd717e3a4cbbe6074f (diff)
WIP mixed static/dynamic sparsitysparse
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/AST.hs33
-rw-r--r--src/AST/Accum.hs17
-rw-r--r--src/AST/Bindings.hs2
-rw-r--r--src/AST/Count.hs6
-rw-r--r--src/AST/Env.hs58
-rw-r--r--src/AST/Sparse.hs434
-rw-r--r--src/AST/Types.hs2
-rw-r--r--src/CHAD.hs298
-rw-r--r--src/CHAD/Accum.hs27
-rw-r--r--src/CHAD/EnvDescr.hs20
-rw-r--r--src/CHAD/Types.hs16
-rw-r--r--src/Data/VarMap.hs4
13 files changed, 747 insertions, 172 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index b0ed639..b8510d2 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -18,13 +18,13 @@ library
AST.Count
AST.Env
AST.Pretty
+ AST.Sparse
AST.SplitLets
AST.Types
AST.UnMonoid
AST.Weaken
AST.Weaken.Auto
CHAD
- CHAD.Accum
CHAD.EnvDescr
CHAD.Top
CHAD.Types
diff --git a/src/AST.hs b/src/AST.hs
index 149cddd..0000836 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -503,11 +503,40 @@ eshapeEmpty (SS n) e =
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
-ezeroD2 :: STy t -> Ex env (D2 t)
-ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext)
+-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
+-- ezeroD2 t ezi = EZero ext (d2M t) ezi
-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev
+
+eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
+eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k =
+ elet e $
+ k WSink
+ (EFst ext (evar IZ))
+ (ESnd ext (evar IZ))
+
+elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
+elet rhs body
+ | Dict <- styKnown (typeOf rhs)
+ = ELet ext rhs body
+
+emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
+emaybe e a b
+ | STMaybe t <- typeOf e
+ , Dict <- styKnown t
+ = EMaybe ext a b e
+
+elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c
+elcase e a b c
+ | STLEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ELCase ext e a b c
+
+evar :: KnownTy a => Idx env a -> Ex env a
+evar = EVar ext knownTy
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 03369c8..1101cc0 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,14 +1,11 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
import AST.Types
-import CHAD.Types
import Data
@@ -75,20 +72,6 @@ tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
-lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
-lemZeroInfoD2 STNil = Refl
-lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STScal STI32) = Refl
-lemZeroInfoD2 (STScal STI64) = Refl
-lemZeroInfoD2 (STScal STF32) = Refl
-lemZeroInfoD2 (STScal STF64) = Refl
-lemZeroInfoD2 (STScal STBool) = Refl
-lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
-
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 745a93b..2310f4b 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -69,7 +69,7 @@ collectBindings = \env -> fst . go env WId
where
go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
go _ _ SETop = (BTop, WId)
- go (ty `SCons` env) w (SEYes sub) =
+ go (ty `SCons` env) w (SEYesR sub) =
let (bs, w') = go env (WPop w) sub
in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 0c682c6..03a36f6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -154,7 +154,7 @@ deleteUnused (_ `SCons` env) OccEnd k =
deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
deleteUnused env occenv $ \sub ->
case count of Zero -> k (SENo sub)
- _ -> k (SEYes sub)
+ _ -> k (SEYesR sub)
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
@@ -163,7 +163,7 @@ unsafeWeakenWithSubenv = \sub ->
Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
where
sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYes _) = Just IZ
+ sinkViaSubenv IZ (SEYesR _) = Just IZ
sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
index 4f34166..bc2b9e0 100644
--- a/src/AST/Env.hs
+++ b/src/AST/Env.hs
@@ -1,59 +1,73 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Env where
+import Data.Type.Equality
+
+import AST.Sparse
import AST.Weaken
import Data
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
+data Subenv' s env env' where
+ SETop :: Subenv' s '[] '[]
+ SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
+ SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
+deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
+
+type Subenv = Subenv' (:~:)
+type SubenvS = Subenv' Sparse
+
+pattern SEYesR :: forall tenv tenv'. ()
+ => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
+ => Subenv env env' -> Subenv tenv tenv'
+pattern SEYesR s = SEYes Refl s
+
+{-# COMPLETE SETop, SEYesR, SENo #-}
-subList :: SList f env -> Subenv env env' -> SList f env'
+subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
+subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
-subenvAll :: SList f env -> Subenv env env
+subenvAll :: IsSubType s => SList f env -> Subenv' s env env
subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
+subenvAll (SCons _ env) = SEYes subtFull (subenvAll env)
-subenvNone :: SList f env -> Subenv env '[]
+subenvNone :: SList f env -> Subenv' s env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
+subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t]
+subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env)
subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
subenvOnehot SNil i = case i of {}
-subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3
+subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
-subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
-subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1')
+subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
subenvConcat sub1 SETop = sub1
-subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2)
+subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
-sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0
+sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
sinkWithSubenv SETop = WId
-sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
sinkWithSubenv (SENo sub) = sinkWithSubenv sub
-wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
wUndoSubenv SETop = WId
-wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
new file mode 100644
index 0000000..09dbc70
--- /dev/null
+++ b/src/AST/Sparse.hs
@@ -0,0 +1,434 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-}
+module AST.Sparse where
+
+import Data.Kind (Constraint, Type)
+import Data.Type.Equality
+
+import AST
+
+
+data Sparse t t' where
+ SpDense :: Sparse t t
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpLeft :: Sparse a a' -> Sparse (TLEither a b) a'
+ SpRight :: Sparse b b' -> Sparse (TLEither a b) b'
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpJust :: Sparse t t' -> Sparse (TMaybe t) t'
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+deriving instance Show (Sparse t t')
+
+applySparse :: Sparse t t' -> STy t -> STy t'
+applySparse SpDense t = t
+applySparse (SpSparse s) t = STMaybe (applySparse s t)
+applySparse SpAbsent _ = STNil
+applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1
+applySparse (SpRight s) (STLEither _ t2) = applySparse s t2
+applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+applySparse (SpJust s) (STMaybe t) = applySparse s t
+applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+
+
+class IsSubType s where
+ type IsSubTypeSubject s (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: s a a
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ STy
+ subtApply = applySparse
+
+ subtTrans SpDense s = s
+ subtTrans s SpDense = s
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2)
+ subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2)
+ subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2)
+ subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2)
+ subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+
+ subtFull = SpDense
+
+
+data SBool b where
+ SF :: SBool False
+ ST :: SBool True
+deriving instance Show (SBool b)
+
+data Injection sp a b where
+ -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
+ -- 'sparsePlusS' can provide injections even if the caller doesn't require
+ -- them. This eliminates pointless checks.
+ Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
+ Noinj :: Injection False a b
+
+withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
+withInj (Inj f) k = Inj (k f)
+withInj Noinj _ = Noinj
+
+withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
+ -> ((forall e. Ex e a1 -> Ex e b1)
+ -> (forall e. Ex e a2 -> Ex e b2)
+ -> (forall e'. Ex e' a' -> Ex e' b'))
+ -> Injection sp a' b'
+withInj2 (Inj f) (Inj g) k = Inj (k f g)
+withInj2 Noinj _ _ = Noinj
+withInj2 _ Noinj _ = Noinj
+
+-- | This function produces quadratically-sized code in the presence of nested
+-- dynamic sparsity. しょうがない。
+sparsePlusS
+ :: SBool inj1 -> SBool inj2
+ -> SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
+ -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+-- nil override
+sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext)
+
+-- simplifications
+sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
+ sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b)
+sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
+ sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext))
+
+sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
+ let ta = applySparse sp1 (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
+ let tb = applySparse sp2 (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
+ let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
+ let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
+ let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
+ let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
+sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
+
+-- TODO: sparse of Just is just Maybe
+
+-- dense plus
+sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+
+-- handle absents
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b)
+sparsePlusS ST _ t SpAbsent sp2 k =
+ k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b)
+
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a)
+sparsePlusS _ ST t sp1 SpAbsent k =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a)
+
+-- double sparse yields sparse
+sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- single sparse can yield non-sparse if the other argument is always present
+sparsePlusS SF _ t (SpSparse sp1) sp2 k =
+ sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
+ k sp3 Noinj (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
+ sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
+ k sp3 inj2 inj1 (flip plus)
+
+-- products
+sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
+ sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
+ sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
+ k (SpPair sp3a sp3b)
+ (withInj2 minj13a minj13b $ \inj13a inj13b ->
+ \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
+ (withInj2 minj23a minj23b $ \inj23a inj23b ->
+ \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
+ (\x1 x2 ->
+ eunPair x1 $ \w1 x1a x1b ->
+ eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
+ EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
+sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k
+
+-- coproducts
+sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
+ sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
+ let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
+ inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
+ inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
+ in
+ k (SpLEither sp3a sp3b)
+ (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
+ (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x2 $
+ elcase (weakenExpr WSink x1)
+ (elcase (evar IZ)
+ nil
+ (inl (inj23a (evar IZ)))
+ (inr (inj23b (evar IZ))))
+ (elcase (evar (IS IZ))
+ (inl (inj13a (evar IZ)))
+ (inl (plusa (evar (IS IZ)) (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
+ (elcase (evar (IS IZ))
+ (inr (inj13b (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
+ (inr (plusb (evar (IS IZ)) (evar IZ)))))
+sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
+
+-- coproducts with partially known arguments: if we have a non-nil
+-- always-present coproduct argument, the result is dense, otherwise we
+-- introduce sparsity
+sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k =
+ sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa ->
+ k (SpLeft sp3a)
+ (Inj inj13a)
+ Noinj
+ (\x1 x2 ->
+ elet x1 $
+ elcase (weakenExpr WSink x2)
+ (inj13a (evar IZ))
+ (plusa (evar (IS IZ)) (evar IZ))
+ (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr"))
+
+sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa ->
+ k (SpSparse (SpLeft sp3a))
+ (Inj $ \x1 -> EJust ext (inj13a x1))
+ (Inj $ \x2 ->
+ elcase x2
+ (ENothing ext (applySparse sp3a (fromSMTy ta)))
+ (EJust ext (inj23a (evar IZ)))
+ (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr"))
+ (\x1 x2 ->
+ elet x1 $
+ EJust ext $
+ elcase (weakenExpr WSink x2)
+ (inj13a (evar IZ))
+ (plusa (evar (IS IZ)) (evar IZ))
+ (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr"))
+
+sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k =
+ sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa)
+sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
+
+sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k =
+ sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb ->
+ k (SpRight sp3b)
+ (Inj inj13b)
+ Noinj
+ (\x1 x2 ->
+ elet x1 $
+ elcase (weakenExpr WSink x2)
+ (inj13b (evar IZ))
+ (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll")
+ (plusb (evar (IS IZ)) (evar IZ)))
+
+sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k =
+ sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb ->
+ k (SpSparse (SpRight sp3b))
+ (Inj $ \x1 -> EJust ext (inj13b x1))
+ (Inj $ \x2 ->
+ elcase x2
+ (ENothing ext (applySparse sp3b (fromSMTy tb)))
+ (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll")
+ (EJust ext (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x1 $
+ EJust ext $
+ elcase (weakenExpr WSink x2)
+ (inj13b (evar IZ))
+ (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll")
+ (plusb (evar (IS IZ)) (evar IZ)))
+
+sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k =
+ sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb)
+sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
+
+-- dense same-branch coproducts simply recurse
+sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k =
+ sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus ->
+ k (SpLeft sp3) inj1 inj2 plus
+sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k =
+ sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus ->
+ k (SpRight sp3) inj1 inj2 plus
+
+-- dense, mismatched coproducts are valid as long as we don't actually invoke
+-- plus at runtime (injections are fine)
+sparsePlusS SF SF _ SpLeft{} SpRight{} k =
+ k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr")
+sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k =
+ k (SpRight sp2) Noinj (Inj id)
+ (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr")
+sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k =
+ k (SpLeft sp1) (Inj id) Noinj
+ (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll")
+sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k =
+ -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it.
+ k (SpLEither sp1 sp2)
+ (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a)
+ (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b)
+ (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr")
+
+sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh
+ sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus)
+
+-- maybe
+sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k
+
+-- maybe with partially known arguments: if we have an always-present Just
+-- argument, the result is dense, otherwise we introduce sparsity by weakening
+-- to SpMaybe
+sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k =
+ sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus ->
+ k (SpJust sp3)
+ (Inj inj1)
+ Noinj
+ (\a b ->
+ elet a $
+ emaybe (weakenExpr WSink b)
+ (inj1 (evar IZ))
+ (plus (evar (IS IZ)) (evar IZ)))
+sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> EJust ext (inj1 a))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet a $
+ emaybe (weakenExpr WSink b)
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ))))
+
+sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k =
+ sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus)
+sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k
+sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k
+
+-- dense same-branch maybes simply recurse
+sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus ->
+ k (SpJust sp3) inj1 inj2 plus
+
+-- dense array cotangents simply recurse
+sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
+ k (SpArr sp3)
+ (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
+ (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+ (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
+ (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k
+sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index a3b7302..42bfb92 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -5,9 +5,9 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeData #-}
module AST.Types where
import Data.Int (Int32, Int64)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index df792ce..b5a9af0 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -42,8 +42,8 @@ import AST
import AST.Bindings
import AST.Count
import AST.Env
+import AST.Sparse
import AST.Weaken.Auto
-import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -65,7 +65,7 @@ tapeTy (SCons t ts) = STPair t (tapeTy ts)
bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
-> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
bindingsCollectTape BTop SETop _ = ENil ext
-bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
(bindingsCollectTape binds sub (w .> WSink))
bindingsCollectTape (BPush binds _) (SENo sub) w =
@@ -227,26 +227,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
-> D2Op (TScal a) t
@@ -261,11 +272,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -293,7 +304,7 @@ conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
= Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
- | Idx2Me (Idx (Select env sto "merge") t)
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
@@ -317,44 +328,127 @@ conv2Idx DTop i = case i of {}
------------------------------------ MONOIDS -----------------------------------
-zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
-zeroTup SNil = ENil ext
-zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t)
-
-
------------------------------------- SUBENVS -----------------------------------
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
+zeroTup SNil _ = ENil ext
+zeroTup (t `SCons` env) w =
+ EPair ext (zeroTup env (WPop w))
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse _ SpDense _ e = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STEither t1 t2) (SpLeft s) epr e =
+ let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
+ in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
+expandSparse (STEither t1 t2) (SpRight s) epr e =
+ let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
+ in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLeft s) epr e =
+ let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
+ in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
+expandSparse (STLEither t1 t2) (SpRight s) epr e =
+ let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
+ in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+sparsePlus
+ :: SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus
subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
+ -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2
+ -> (forall env3. SubenvS (D2E env) env3
+ -> SubenvS env3 env1
+ -> SubenvS env3 env2
+ -> (Ex exenv (Tup env1)
+ -> Ex exenv (Tup env2)
+ -> Ex exenv (Tup env3))
-> r)
-> r
subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
+subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
+ k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 ->
ELet ext e1 $
EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
(weakenExpr WSink e2))
(ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
+subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
+ k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 ->
ELet ext e2 $
EPair ext (pl (weakenExpr WSink e1)
(EFst ext (EVar ext (typeOf e2) IZ)))
(ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
+subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
+ k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 ->
ELet ext e1 $
ELet ext (weakenExpr WSink e2) $
EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
@@ -363,22 +457,44 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
(ESnd ext (EVar ext (typeOf e1) (IS IZ)))
(ESnd ext (EVar ext (typeOf e2) IZ)))
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t)
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId (Just (VIArr i _)) = Just i
fromArrayValId _ = Nothing
@@ -422,7 +538,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
k (storepl `DPush` (t, vid, SAccum))
envpro
prosub
- (SEYes accrevsub)
+ (SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
@@ -449,7 +565,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
k (storepl `DPush` (t, vid, SAccum))
(t `SCons` envpro)
- (SEYes prosub)
+ (SEYesR prosub)
(SENo accrevsub)
(let accumMap' = VarMap.sink1 accumMap
in case fromArrayValId vid of
@@ -499,19 +615,21 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
data Ret env0 sto t =
- forall shbinds tapebinds env0Merge.
+ forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
deriving instance Show (Ret env0 sto t)
data RetPair env0 sto env shbinds tapebinds t =
- forall env0Merge.
+ forall contribs.
RetPair (Ex (Append shbinds env) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
data Rets env0 sto env list =
@@ -569,18 +687,24 @@ freezeRet :: Descr env sto
freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (subList (d2e (select SMerge descr)) sub)
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tape (subList (bindingsBinds e0) subtape)
- &. #shbinds (bindingsBinds e0)
- &. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr))
+ (ELet ext (weakenExpr (autoWeak library
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
- expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
---------------------------- THE CHAD TRANSFORMATION ---------------------------
@@ -596,21 +720,21 @@ drev des accumMap = \case
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (select SMerge des) tupI)
+ (subenvOnehot (d2e (select SMerge des)) tupI)
(EPair ext (ENil ext) (EVar ext (d2 t) IZ))
Idx2Di _ ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
@@ -621,7 +745,7 @@ drev des accumMap = \case
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
, Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
(weakenExpr wbody0' body1)
@@ -637,7 +761,7 @@ drev des accumMap = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
@@ -649,16 +773,13 @@ drev des accumMap = \case
subtape
(EPair ext a1 b1)
subBoth
- (EMaybe ext
- (zeroTup (subList (select SMerge des) subBoth))
- (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
- (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ))
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ))
EFst _ e
| Ret e0 subtape e1 sub e2 <- drev des accumMap e
@@ -732,7 +853,7 @@ drev des accumMap = \case
ECase ext e1
(letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
(letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
- (SEYes subtapeE)
+ (SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
(ELet ext
@@ -801,7 +922,7 @@ drev des accumMap = \case
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
Ret (e0 `BPush` (d1 (typeOf e), e1))
- (SEYes subtape)
+ (SEYesR subtape)
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
@@ -816,7 +937,7 @@ drev des accumMap = \case
`BPush` (typeOf b1, weakenExpr WSink b1)
`BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
`BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYes (SENo (SENo (SENo subtape))))
+ (SEYesR (SENo (SENo (SENo subtape))))
(EFst ext (EVar ext (typeOf pr) (IS IZ)))
bsub
(ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
@@ -865,7 +986,7 @@ drev des accumMap = \case
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
+ let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
@@ -894,7 +1015,7 @@ drev des accumMap = \case
in EPair ext (weakenExpr w e1) (collectexpr w)))
`BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYes (SENo (SEYes SETop)))
+ (SEYesR (SENo (SEYesR SETop)))
(emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
@@ -981,7 +1102,7 @@ drev des accumMap = \case
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
(weakenExpr (WSink .> WSink) ei1))
sub
@@ -1002,7 +1123,7 @@ drev des accumMap = \case
Ret (binds `BPush` (STArr n (d1 eltty), e1)
`BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
`BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYes (SEYes (SENo subtape)))
+ (SEYesR (SEYesR (SENo subtape)))
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
@@ -1030,7 +1151,7 @@ drev des accumMap = \case
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
(EMaybe ext
@@ -1076,7 +1197,7 @@ drev des accumMap = \case
, let tIxN = tTup (sreplicate (SS n) tIx) =
Ret (e0 `BPush` (at, e1)
`BPush` (at', extremum (EVar ext at IZ)))
- (SEYes (SEYes subtape))
+ (SEYesR (SEYesR subtape))
(EVar ext at' IZ)
sub
(EMaybe ext
@@ -1094,16 +1215,17 @@ drev des accumMap = \case
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
data RetScoped env0 sto a s t =
- forall shbinds tapebinds env0Merge.
+ forall shbinds tapebinds contribs.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a))))
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) (D2 a))))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
deriving instance Show (RetScoped env0 sto a s t)
@@ -1118,7 +1240,7 @@ drevScoped des accumMap argty argsto argids expr = case argsto of
SMerge
| Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
case sub of
- SEYes sub' -> RetScoped e0 subtape e1 sub' e2
+ SEYesR sub' -> RetScoped e0 subtape e1 sub' e2
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))
SAccum
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
deleted file mode 100644
index d8a71b5..0000000
--- a/src/CHAD/Accum.hs
+++ /dev/null
@@ -1,27 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-module CHAD.Accum where
-
-import AST
-import CHAD.Types
-import Data
-
-
-
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t =
- makeAccumulators envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
index 4c287d7..49ae0e6 100644
--- a/src/CHAD/EnvDescr.hs
+++ b/src/CHAD/EnvDescr.hs
@@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env'
-> r)
-> r
subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k =
+subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
- SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e)
+ SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e)
+ SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e)
+ SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)
subDescr (des `DPush` (_, _, sto)) (SENo sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
@@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des
select s@SAccum (DPush des (_, _, SDiscr)) = select s des
select s@SMerge (DPush des (_, _, SDiscr)) = select s des
select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des)
+
+selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s)
+selectSub _ DTop = SETop
+selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 974669d..83f013d 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
+import AST.Accum
import AST.Types
import Data
@@ -18,11 +19,11 @@ type family D1 t where
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
D2 (TEither a b) = TLEither (D2 a) (D2 b)
D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TMaybe (TArr n (D2 t))
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -60,11 +61,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env
d2M :: STy t -> SMTy (D2 t)
d2M STNil = SMTNil
-d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STMaybe t) = SMTMaybe (d2M t)
-d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STArr n t) = SMTArr n (d2M t)
d2M (STScal t) = case t of
STI32 -> SMTNil
STI64 -> SMTNil
@@ -116,3 +117,10 @@ chcSetAccum c = c { chcLetArrayAccum = True
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs
index 9c10421..2712b08 100644
--- a/src/Data/VarMap.hs
+++ b/src/Data/VarMap.hs
@@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'
subMap subenv =
let bools = let loop :: Subenv env env' -> [Bool]
loop SETop = []
- loop (SEYes sub) = True : loop sub
+ loop (SEYesR sub) = True : loop sub
loop (SENo sub) = False : loop sub
in VS.fromList $ loop subenv
newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools
@@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env
superMap subenv =
let loop :: Subenv env env' -> Int -> [Int]
loop SETop _ = []
- loop (SEYes sub) i = i : loop sub (i+1)
+ loop (SEYesR sub) i = i : loop sub (i+1)
loop (SENo sub) i = loop sub (i+1)
newIndices = VS.fromList $ loop subenv 0