diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-06 22:50:06 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-06 22:50:06 +0200 |
commit | 56056c98b2e3dce65a0e42bce0410c083fd1f8be (patch) | |
tree | 8db2d1be037f8f980c3d1bf76ff9078048f33d63 | |
parent | 7bd37711ffecb7a0e202ecfd717e3a4cbbe6074f (diff) |
WIP mixed static/dynamic sparsitysparse
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/AST.hs | 33 | ||||
-rw-r--r-- | src/AST/Accum.hs | 17 | ||||
-rw-r--r-- | src/AST/Bindings.hs | 2 | ||||
-rw-r--r-- | src/AST/Count.hs | 6 | ||||
-rw-r--r-- | src/AST/Env.hs | 58 | ||||
-rw-r--r-- | src/AST/Sparse.hs | 434 | ||||
-rw-r--r-- | src/AST/Types.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 298 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 27 | ||||
-rw-r--r-- | src/CHAD/EnvDescr.hs | 20 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 16 | ||||
-rw-r--r-- | src/Data/VarMap.hs | 4 |
13 files changed, 747 insertions, 172 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index b0ed639..b8510d2 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -18,13 +18,13 @@ library AST.Count AST.Env AST.Pretty + AST.Sparse AST.SplitLets AST.Types AST.UnMonoid AST.Weaken AST.Weaken.Auto CHAD - CHAD.Accum CHAD.EnvDescr CHAD.Top CHAD.Types @@ -503,11 +503,40 @@ eshapeEmpty (SS n) e = (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) -ezeroD2 :: STy t -> Ex env (D2 t) -ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi -- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil -- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea -- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) -- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = ELet ext rhs body + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 03369c8..1101cc0 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where import AST.Types -import CHAD.Types import Data @@ -75,20 +72,6 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" - -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 745a93b..2310f4b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -69,7 +69,7 @@ collectBindings = \env -> fst . go env WId where go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) go _ _ SETop = (BTop, WId) - go (ty `SCons` env) w (SEYes sub) = + go (ty `SCons` env) w (SEYesR sub) = let (bs, w') = go env (WPop w) sub in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 0c682c6..03a36f6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -154,7 +154,7 @@ deleteUnused (_ `SCons` env) OccEnd k = deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = deleteUnused env occenv $ \sub -> case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) + _ -> k (SEYesR sub) unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t unsafeWeakenWithSubenv = \sub -> @@ -163,7 +163,7 @@ unsafeWeakenWithSubenv = \sub -> Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") where sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ + sinkViaSubenv IZ (SEYesR _) = Just IZ sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..bc2b9e0 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,73 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where +import Data.Type.Equality + +import AST.Sparse import AST.Weaken import Data -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} -subList :: SList f env -> Subenv env env' -> SList f env' +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: IsSubType s => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons _ env) = SEYes subtFull (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t] +subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env) subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) subenvOnehot SNil i = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..09dbc70 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,434 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-} +module AST.Sparse where + +import Data.Kind (Constraint, Type) +import Data.Type.Equality + +import AST + + +data Sparse t t' where + SpDense :: Sparse t t + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpLeft :: Sparse a a' -> Sparse (TLEither a b) a' + SpRight :: Sparse b b' -> Sparse (TLEither a b) b' + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpJust :: Sparse t t' -> Sparse (TMaybe t) t' + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') +deriving instance Show (Sparse t t') + +applySparse :: Sparse t t' -> STy t -> STy t' +applySparse SpDense t = t +applySparse (SpSparse s) t = STMaybe (applySparse s t) +applySparse SpAbsent _ = STNil +applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1 +applySparse (SpRight s) (STLEither _ t2) = applySparse s t2 +applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) +applySparse (SpJust s) (STMaybe t) = applySparse s t +applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + + +class IsSubType s where + type IsSubTypeSubject s (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: s a a + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ STy + subtApply = applySparse + + subtTrans SpDense s = s + subtTrans s SpDense = s + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2) + subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2) + subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2) + subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2 + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2) + subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + + subtFull = SpDense + + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This eliminates pointless checks. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. しょうがない。 +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override +sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) +sparsePlusS ST _ t SpAbsent sp2 k = + k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) +sparsePlusS _ ST t sp1 SpAbsent k = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) +sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- coproducts with partially known arguments: if we have a non-nil +-- always-present coproduct argument, the result is dense, otherwise we +-- introduce sparsity +sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = + sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa -> + k (SpLeft sp3a) + (Inj inj13a) + Noinj + (\x1 x2 -> + elet x1 $ + elcase (weakenExpr WSink x2) + (inj13a (evar IZ)) + (plusa (evar (IS IZ)) (evar IZ)) + (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = + sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa -> + k (SpSparse (SpLeft sp3a)) + (Inj $ \x1 -> EJust ext (inj13a x1)) + (Inj $ \x2 -> + elcase x2 + (ENothing ext (applySparse sp3a (fromSMTy ta))) + (EJust ext (inj23a (evar IZ))) + (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr")) + (\x1 x2 -> + elet x1 $ + EJust ext $ + elcase (weakenExpr WSink x2) + (inj13a (evar IZ)) + (plusa (evar (IS IZ)) (evar IZ)) + (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa) +sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = + sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb -> + k (SpRight sp3b) + (Inj inj13b) + Noinj + (\x1 x2 -> + elet x1 $ + elcase (weakenExpr WSink x2) + (inj13b (evar IZ)) + (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") + (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = + sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb -> + k (SpSparse (SpRight sp3b)) + (Inj $ \x1 -> EJust ext (inj13b x1)) + (Inj $ \x2 -> + elcase x2 + (ENothing ext (applySparse sp3b (fromSMTy tb))) + (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll") + (EJust ext (inj23b (evar IZ)))) + (\x1 x2 -> + elet x1 $ + EJust ext $ + elcase (weakenExpr WSink x2) + (inj13b (evar IZ)) + (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") + (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb) +sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- dense same-branch coproducts simply recurse +sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k = + sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpLeft sp3) inj1 inj2 plus +sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k = + sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpRight sp3) inj1 inj2 plus + +-- dense, mismatched coproducts are valid as long as we don't actually invoke +-- plus at runtime (injections are fine) +sparsePlusS SF SF _ SpLeft{} SpRight{} k = + k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr") +sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k = + k (SpRight sp2) Noinj (Inj id) + (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr") +sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k = + k (SpLeft sp1) (Inj id) Noinj + (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll") +sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k = + -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it. + k (SpLEither sp1 sp2) + (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a) + (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b) + (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr") + +sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh + sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- maybe with partially known arguments: if we have an always-present Just +-- argument, the result is dense, otherwise we introduce sparsity by weakening +-- to SpMaybe +sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = + sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus -> + k (SpJust sp3) + (Inj inj1) + Noinj + (\a b -> + elet a $ + emaybe (weakenExpr WSink b) + (inj1 (evar IZ)) + (plus (evar (IS IZ)) (evar IZ))) +sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> EJust ext (inj1 a)) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet a $ + emaybe (weakenExpr WSink b) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ)))) + +sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus) +sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- dense same-branch maybes simply recurse +sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpJust sp3) inj1 inj2 plus + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) +sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k +sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k diff --git a/src/AST/Types.hs b/src/AST/Types.hs index a3b7302..42bfb92 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -5,9 +5,9 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-} module AST.Types where import Data.Int (Int32, Int64) diff --git a/src/CHAD.hs b/src/CHAD.hs index df792ce..b5a9af0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -42,8 +42,8 @@ import AST import AST.Bindings import AST.Count import AST.Env +import AST.Sparse import AST.Weaken.Auto -import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -65,7 +65,7 @@ tapeTy (SCons t ts) = STPair t (tapeTy ts) bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w = EPair ext (EVar ext t (w @> IZ)) (bindingsCollectTape binds sub (w .> WSink)) bindingsCollectTape (BPush binds _) (SENo sub) w = @@ -227,26 +227,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) -> D2Op (TScal a) t @@ -261,11 +272,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +304,7 @@ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -317,44 +328,127 @@ conv2Idx DTop i = case i of {} ------------------------------------ MONOIDS ----------------------------------- -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) - - ------------------------------------- SUBENVS ----------------------------------- +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) +zeroTup SNil _ = ENil ext +zeroTup (t `SCons` env) w = + EPair ext (zeroTup env (WPop w)) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +----------------------------------- SPARSITY ----------------------------------- + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse _ SpDense _ e = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STEither t1 t2) (SpLeft s) epr e = + let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") + in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STEither t1 t2) (SpRight s) epr e = + let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) + in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLeft s) epr e = + let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") + in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STLEither t1 t2) (SpRight s) epr e = + let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) + in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +sparsePlus + :: SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) + -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2 + -> (forall env3. SubenvS (D2E env) env3 + -> SubenvS env3 env1 + -> SubenvS env3 env2 + -> (Ex exenv (Tup env1) + -> Ex exenv (Tup env2) + -> Ex exenv (Tup env3)) -> r) -> r subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = +subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> + k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 -> ELet ext e1 $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = +subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> + k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 -> ELet ext e2 $ EPair ext (pl (weakenExpr WSink e1) (EFst ext (EVar ext (typeOf e2) IZ))) (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = +subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> + k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 -> ELet ext e1 $ ELet ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) @@ -363,22 +457,44 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ))) -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" +assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing @@ -422,7 +538,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of k (storepl `DPush` (t, vid, SAccum)) envpro prosub - (SEYes accrevsub) + (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) @@ -449,7 +565,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) (t `SCons` envpro) - (SEYes prosub) + (SEYesR prosub) (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of @@ -499,19 +615,21 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- data Ret env0 sto t = - forall shbinds tapebinds env0Merge. + forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (forall sd. Sparse (D2 t) sd + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) deriving instance Show (Ret env0 sto t) data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. + forall contribs. RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (forall sd. Sparse (D2 t) sd + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) deriving instance Show (RetPair env0 sto env shbinds tapebinds t) data Rets env0 sto env list = @@ -569,18 +687,24 @@ freezeRet :: Descr env sto freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (subList (d2e (select SMerge descr)) sub) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) + (ELet ext (weakenExpr (autoWeak library (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) ---------------------------- THE CHAD TRANSFORMATION --------------------------- @@ -596,21 +720,21 @@ drev des accumMap = \case Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) + (subenvOnehot (d2e (select SMerge des)) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) Idx2Di _ -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) ELet _ (rhs :: Expr _ _ a) body @@ -621,7 +745,7 @@ drev des accumMap = \case , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) (weakenExpr wbody0' body1) @@ -637,7 +761,7 @@ drev des accumMap = \case (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b @@ -649,16 +773,13 @@ drev des accumMap = \case subtape (EPair ext a1 b1) subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ)) + (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ)) EFst _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap e @@ -732,7 +853,7 @@ drev des accumMap = \case ECase ext e1 (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) + (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut (ELet ext @@ -801,7 +922,7 @@ drev des accumMap = \case (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) + (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) @@ -816,7 +937,7 @@ drev des accumMap = \case `BPush` (typeOf b1, weakenExpr WSink b1) `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) + (SEYesR (SENo (SENo (SENo subtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ @@ -865,7 +986,7 @@ drev des accumMap = \case , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in @@ -894,7 +1015,7 @@ drev des accumMap = \case in EPair ext (weakenExpr w e1) (collectexpr w))) `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) + (SEYesR (SENo (SEYesR SETop))) (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvCompose subMergeUsed proSub) @@ -981,7 +1102,7 @@ drev des accumMap = \case , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) sub @@ -1002,7 +1123,7 @@ drev des accumMap = \case Ret (binds `BPush` (STArr n (d1 eltty), e1) `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) + (SEYesR (SEYesR (SENo subtape))) (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub @@ -1030,7 +1151,7 @@ drev des accumMap = \case , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub (EMaybe ext @@ -1076,7 +1197,7 @@ drev des accumMap = \case , let tIxN = tTup (sreplicate (SS n) tIx) = Ret (e0 `BPush` (at, e1) `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) + (SEYesR (SEYesR subtape)) (EVar ext at' IZ) sub (EMaybe ext @@ -1094,16 +1215,17 @@ drev des accumMap = \case data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. + forall shbinds tapebinds contribs. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) + (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) + (forall sd. Sparse (D2 t) sd + -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) (D2 a)))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) deriving instance Show (RetScoped env0 sto a s t) @@ -1118,7 +1240,7 @@ drevScoped des accumMap argty argsto argids expr = case argsto of SMerge | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 + SEYesR sub' -> RetScoped e0 subtape e1 sub' e2 SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) SAccum diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index d8a71b5..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,27 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -module CHAD.Accum where - -import AST -import CHAD.Types -import Data - - - -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = - makeAccumulators envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index 4c287d7..49ae0e6 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des select s@SAccum (DPush des (_, _, SDiscr)) = select s des select s@SMerge (DPush des (_, _, SDiscr)) = select s des select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 974669d..83f013d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -3,6 +3,7 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Types where +import AST.Accum import AST.Types import Data @@ -18,11 +19,11 @@ type family D1 t where type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) D2 (TEither a b) = TLEither (D2 a) (D2 b) D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -60,11 +61,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env d2M :: STy t -> SMTy (D2 t) d2M STNil = SMTNil -d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STPair a b) = SMTPair (d2M a) (d2M b) d2M (STEither a b) = SMTLEither (d2M a) (d2M b) d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STArr n t) = SMTArr n (d2M t) d2M (STScal t) = case t of STI32 -> SMTNil STI64 -> SMTNil @@ -116,3 +117,10 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs index 9c10421..2712b08 100644 --- a/src/Data/VarMap.hs +++ b/src/Data/VarMap.hs @@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' subMap subenv = let bools = let loop :: Subenv env env' -> [Bool] loop SETop = [] - loop (SEYes sub) = True : loop sub + loop (SEYesR sub) = True : loop sub loop (SENo sub) = False : loop sub in VS.fromList $ loop subenv newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env superMap subenv = let loop :: Subenv env env' -> Int -> [Int] loop SETop _ = [] - loop (SEYes sub) i = i : loop sub (i+1) + loop (SEYesR sub) i = i : loop sub (i+1) loop (SENo sub) i = loop sub (i+1) newIndices = VS.fromList $ loop subenv 0 |