diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 33 | ||||
| -rw-r--r-- | src/AST/Accum.hs | 17 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 2 | ||||
| -rw-r--r-- | src/AST/Count.hs | 6 | ||||
| -rw-r--r-- | src/AST/Env.hs | 58 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 434 | ||||
| -rw-r--r-- | src/AST/Types.hs | 2 | ||||
| -rw-r--r-- | src/CHAD.hs | 294 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 27 | ||||
| -rw-r--r-- | src/CHAD/EnvDescr.hs | 20 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 16 | ||||
| -rw-r--r-- | src/Data/VarMap.hs | 4 | 
12 files changed, 744 insertions, 169 deletions
| @@ -503,11 +503,40 @@ eshapeEmpty (SS n) e =                                        (EConst ext STI64 0)))        (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) -ezeroD2 :: STy t -> Ex env (D2 t) -ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi  -- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil  -- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea  -- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)  -- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k = +  elet e $ +    k WSink +      (EFst ext (evar IZ)) +      (ESnd ext (evar IZ)) + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body +  | Dict <- styKnown (typeOf rhs) +  = ELet ext rhs body + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b +  | STMaybe t <- typeOf e +  , Dict <- styKnown t +  = EMaybe ext a b e + +elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c +elcase e a b c +  | STLEither t1 t2 <- typeOf e +  , Dict <- styKnown t1 +  , Dict <- styKnown t2 +  = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 03369c8..1101cc0 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,11 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-}  {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE UndecidableInstances #-}  module AST.Accum where  import AST.Types -import CHAD.Types  import Data @@ -75,20 +72,6 @@ tZeroInfo (SMTMaybe _) = STNil  tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)  tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" -  -- -- | Additional info needed for accumulation. This is empty unless there is  -- -- sparsity in the monoid.  -- type family AccumInfo t where diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 745a93b..2310f4b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -69,7 +69,7 @@ collectBindings = \env -> fst . go env WId    where      go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)      go _ _ SETop = (BTop, WId) -    go (ty `SCons` env) w (SEYes sub) = +    go (ty `SCons` env) w (SEYesR sub) =        let (bs, w') = go env (WPop w) sub        in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')      go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 0c682c6..03a36f6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -154,7 +154,7 @@ deleteUnused (_ `SCons` env) OccEnd k =  deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =    deleteUnused env occenv $ \sub ->      case count of Zero -> k (SENo sub) -                  _    -> k (SEYes sub) +                  _    -> k (SEYesR sub)  unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t  unsafeWeakenWithSubenv = \sub -> @@ -163,7 +163,7 @@ unsafeWeakenWithSubenv = \sub ->                       Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")    where      sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) -    sinkViaSubenv IZ (SEYes _) = Just IZ +    sinkViaSubenv IZ (SEYesR _) = Just IZ      sinkViaSubenv IZ (SENo _) = Nothing -    sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub +    sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub      sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..bc2b9e0 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,73 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE TypeOperators #-}  module AST.Env where +import Data.Type.Equality + +import AST.Sparse  import AST.Weaken  import Data  -- | @env'@ is a subset of @env@: each element of @env@ is either included in  -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where -  SETop :: Subenv '[] '[] -  SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') -  SENo  :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where +  SETop :: Subenv' s '[] '[] +  SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') +  SENo  :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () +               => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') +               => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} -subList :: SList f env -> Subenv env env' -> SList f env' +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'  subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)  subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: IsSubType s => SList f env -> Subenv' s env env  subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons _ env) = SEYes subtFull (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[]  subenvNone SNil = SETop  subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t] +subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env)  subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)  subenvOnehot SNil i = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3  subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)  subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')  subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)  subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0  sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub  sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env  wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)  wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..09dbc70 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,434 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-} +module AST.Sparse where + +import Data.Kind (Constraint, Type) +import Data.Type.Equality + +import AST + + +data Sparse t t' where +  SpDense :: Sparse t t +  SpSparse :: Sparse t t' -> Sparse t (TMaybe t') +  SpAbsent :: Sparse t TNil + +  SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') +  SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') +  SpLeft :: Sparse a a' -> Sparse (TLEither a b) a' +  SpRight :: Sparse b b' -> Sparse (TLEither a b) b' +  SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') +  SpJust :: Sparse t t' -> Sparse (TMaybe t) t' +  SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') +deriving instance Show (Sparse t t') + +applySparse :: Sparse t t' -> STy t -> STy t' +applySparse SpDense t = t +applySparse (SpSparse s) t = STMaybe (applySparse s t) +applySparse SpAbsent _ = STNil +applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1 +applySparse (SpRight s) (STLEither _ t2) = applySparse s t2 +applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) +applySparse (SpJust s) (STMaybe t) = applySparse s t +applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + + +class IsSubType s where +  type IsSubTypeSubject s (f :: k -> Type) :: Constraint +  subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' +  subtTrans :: s a b -> s b c -> s a c +  subtFull :: s a a + +instance IsSubType (:~:) where +  type IsSubTypeSubject (:~:) f = () +  subtApply = gcastWith +  subtTrans = trans +  subtFull = Refl + +instance IsSubType Sparse where +  type IsSubTypeSubject Sparse f = f ~ STy +  subtApply = applySparse + +  subtTrans SpDense s = s +  subtTrans s SpDense = s +  subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) +  subtTrans _ SpAbsent = SpAbsent +  subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) +  subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) +  subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2) +  subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2) +  subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2) +  subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2) +  subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) +  subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2 +  subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) +  subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2) +  subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2) +  subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + +  subtFull = SpDense + + +data SBool b where +  SF :: SBool False +  ST :: SBool True +deriving instance Show (SBool b) + +data Injection sp a b where +  -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that +  -- 'sparsePlusS' can provide injections even if the caller doesn't require +  -- them. This eliminates pointless checks. +  Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b +  Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 +         -> ((forall e. Ex e a1 -> Ex e b1) +             -> (forall e. Ex e a2 -> Ex e b2) +             -> (forall e'. Ex e' a' -> Ex e' b')) +         -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. しょうがない。 +sparsePlusS +  :: SBool inj1 -> SBool inj2 +  -> SMTy t -> Sparse t t1 -> Sparse t t2 +  -> (forall t3. Sparse t t3 +              -> Injection inj1 t1 t3  -- only available if first injection is requested (second argument may be absent) +              -> Injection inj2 t2 t3  -- only available if second injection is requested (first argument may be absent) +              -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) +              -> r) +  -> r +-- nil override +sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = +  sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = +  sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> +    k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = +  let ta = applySparse sp1 (fromSMTy t) in +  sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) +      minj2 +      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = +  let tb = applySparse sp2 (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = +  let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in +  sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) +      minj2 +      (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = +  let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) +      (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = +  let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in +  sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) +      minj2 +      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = +  let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) +      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) +sparsePlusS ST _ t SpAbsent sp2 k = +  k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) +sparsePlusS _ ST t sp1 SpAbsent k = +  k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpSparse sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (emaybe (evar IZ) +              (ENothing ext (applySparse sp3 (fromSMTy t))) +              (EJust ext (inj2 (evar IZ)))) +            (emaybe (evar (IS IZ)) +              (EJust ext (inj1 (evar IZ))) +              (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = +  sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> +    k sp3 Noinj (Inj inj2) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (inj2 (evar IZ)) +            (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpSparse sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> EJust ext (inj2 b)) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (EJust ext (inj2 (evar IZ))) +            (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = +  sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> +    k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = +  sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> +  sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> +    k (SpPair sp3a sp3b) +      (withInj2 minj13a minj13b $ \inj13a inj13b -> +        \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) +      (withInj2 minj23a minj23b $ \inj23a inj23b -> +        \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) +      (\x1 x2 -> +        eunPair x1 $ \w1 x1a x1b -> +        eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> +          EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) +sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = +  sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> +  sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> +    let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) +        inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) +        inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) +    in +    k (SpLEither sp3a sp3b) +      (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) +      (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) +      (\x1 x2 -> +        elet x2 $ +          elcase (weakenExpr WSink x1) +            (elcase (evar IZ) +              nil +              (inl (inj23a (evar IZ))) +              (inr (inj23b (evar IZ)))) +            (elcase (evar (IS IZ)) +              (inl (inj13a (evar IZ))) +              (inl (plusa (evar (IS IZ)) (evar IZ))) +              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) +            (elcase (evar (IS IZ)) +              (inr (inj13b (evar IZ))) +              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") +              (inr (plusb (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- coproducts with partially known arguments: if we have a non-nil +-- always-present coproduct argument, the result is dense, otherwise we +-- introduce sparsity +sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = +  sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa -> +    k (SpLeft sp3a) +      (Inj inj13a) +      Noinj +      (\x1 x2 -> +        elet x1 $ +          elcase (weakenExpr WSink x2) +            (inj13a (evar IZ)) +            (plusa (evar (IS IZ)) (evar IZ)) +            (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = +  sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa -> +    k (SpSparse (SpLeft sp3a)) +      (Inj $ \x1 -> EJust ext (inj13a x1)) +      (Inj $ \x2 -> +        elcase x2 +          (ENothing ext (applySparse sp3a (fromSMTy ta))) +          (EJust ext (inj23a (evar IZ))) +          (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr")) +      (\x1 x2 -> +        elet x1 $ +          EJust ext $ +            elcase (weakenExpr WSink x2) +              (inj13a (evar IZ)) +              (plusa (evar (IS IZ)) (evar IZ)) +              (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k = +  sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa) +sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = +  sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb -> +    k (SpRight sp3b) +      (Inj inj13b) +      Noinj +      (\x1 x2 -> +        elet x1 $ +          elcase (weakenExpr WSink x2) +            (inj13b (evar IZ)) +            (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") +            (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = +  sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb -> +    k (SpSparse (SpRight sp3b)) +      (Inj $ \x1 -> EJust ext (inj13b x1)) +      (Inj $ \x2 -> +        elcase x2 +          (ENothing ext (applySparse sp3b (fromSMTy tb))) +          (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll") +          (EJust ext (inj23b (evar IZ)))) +      (\x1 x2 -> +        elet x1 $ +          EJust ext $ +            elcase (weakenExpr WSink x2) +              (inj13b (evar IZ)) +              (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") +              (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k = +  sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb) +sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- dense same-branch coproducts simply recurse +sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k = +  sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus -> +    k (SpLeft sp3) inj1 inj2 plus +sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k = +  sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus -> +    k (SpRight sp3) inj1 inj2 plus + +-- dense, mismatched coproducts are valid as long as we don't actually invoke +-- plus at runtime (injections are fine) +sparsePlusS SF SF _ SpLeft{} SpRight{} k = +  k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr") +sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k = +  k (SpRight sp2) Noinj (Inj id) +    (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr") +sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k = +  k (SpLeft sp1) (Inj id) Noinj +    (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll") +sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k = +  -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it. +  k (SpLEither sp1 sp2) +    (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a) +    (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b) +    (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr") + +sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k =  -- the errors are not flipped, but eh +  sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpMaybe sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (emaybe (evar IZ) +              (ENothing ext (applySparse sp3 (fromSMTy t))) +              (EJust ext (inj2 (evar IZ)))) +            (emaybe (evar (IS IZ)) +              (EJust ext (inj1 (evar IZ))) +              (EJust ext (plus (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- maybe with partially known arguments: if we have an always-present Just +-- argument, the result is dense, otherwise we introduce sparsity by weakening +-- to SpMaybe +sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = +  sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus -> +    k (SpJust sp3) +      (Inj inj1) +      Noinj +      (\a b -> +        elet a $ +          emaybe (weakenExpr WSink b) +            (inj1 (evar IZ)) +            (plus (evar (IS IZ)) (evar IZ))) +sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpMaybe sp3) +      (Inj $ \a -> EJust ext (inj1 a)) +      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) +      (\a b -> +        elet a $ +          emaybe (weakenExpr WSink b) +            (EJust ext (inj1 (evar IZ))) +            (EJust ext (plus (evar (IS IZ)) (evar IZ)))) + +sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k = +  sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus) +sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- dense same-branch maybes simply recurse +sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k = +  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus -> +    k (SpJust sp3) inj1 inj2 plus + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = +  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> +    k (SpArr sp3) +      (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) +      (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) +      (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) +                      (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) +sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k +sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k diff --git a/src/AST/Types.hs b/src/AST/Types.hs index a3b7302..42bfb92 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -5,9 +5,9 @@  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-}  module AST.Types where  import Data.Int (Int32, Int64) diff --git a/src/CHAD.hs b/src/CHAD.hs index df792ce..b5a9af0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -42,8 +42,8 @@ import AST  import AST.Bindings  import AST.Count  import AST.Env +import AST.Sparse  import AST.Weaken.Auto -import CHAD.Accum  import CHAD.EnvDescr  import CHAD.Types  import Data @@ -65,7 +65,7 @@ tapeTy (SCons t ts) = STPair t (tapeTy ts)  bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds                  -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)  bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w =    EPair ext (EVar ext t (w @> IZ))              (bindingsCollectTape binds sub (w .> WSink))  bindingsCollectTape (BPush binds _) (SENo sub) w = @@ -227,26 +227,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))  d2op :: SOp a t -> D2Op a t  d2op op = case op of -  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) +  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d    OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> -    EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) -                         (EOp ext (OMul t) (EPair ext (EFst ext e) d))) +    EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) +              (EOp ext (OMul t) (EPair ext (EFst ext e) d))    ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d -  OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) -  OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) -  OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) +  OLt t -> Linear $ \_ -> pairZero t +  OLe t -> Linear $ \_ -> pairZero t +  OEq t -> Linear $ \_ -> pairZero t    ONot -> Linear $ \_ -> ENil ext -  OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) -  OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +  OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)    OIf -> Linear $ \_ -> ENil ext -  ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 +  ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)    OToFl64 -> Linear $ \_ -> ENil ext    ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)    OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)    OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) -  OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) -  OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +  OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)    where +    pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) +    pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) +                                     (EZero ext (d2M (STScal t)) (ENil ext)) +      where +        ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r +        ziNil STI32 k = k +        ziNil STI64 k = k +        ziNil STF32 k = k +        ziNil STF64 k = k +        ziNil STBool k = k +      d2opUnArrangeInt :: SScalTy a                       -> (D2s a ~ TScal a => D2Op (TScal a) t)                       -> D2Op (TScal a) t @@ -261,11 +272,11 @@ d2op op = case op of                        -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)                        -> D2Op (TPair (TScal a) (TScal a)) t      d2opBinArrangeInt ty float = case ty of -      STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) -      STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +      STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +      STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)        STF32 -> float        STF64 -> float -      STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +      STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)      floatingD2 :: ScalIsFloating a ~ True                 => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +304,7 @@ conv1Idx (IS i) = IS (conv1Idx i)  data Idx2 env sto t    = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) -  | Idx2Me (Idx (Select env sto "merge") t) +  | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))    | Idx2Di (Idx (Select env sto "discr") t)  conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -317,44 +328,127 @@ conv2Idx DTop i = case i of {}  ------------------------------------ MONOIDS ----------------------------------- -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = +  eunPair e $ \_ e1 e2 -> +    EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) +zeroTup SNil _ = ENil ext +zeroTup (t `SCons` env) w = +  EPair ext (zeroTup env (WPop w)) +            (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +----------------------------------- SPARSITY ----------------------------------- +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) ------------------------------------- SUBENVS ----------------------------------- +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse _ SpDense _ e = e +expandSparse t (SpSparse sp) epr e = +  EMaybe ext +    (EZero ext (d2M t) (d2zeroInfo t epr)) +    (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) +    e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = +  eunPair epr $ \w1 epr1 epr2 -> +  eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> +    EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) +              (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = +  ELCase ext e +    (EZero ext (d2M (STEither t1 t2)) (ENil ext)) +    (ECase ext (weakenExpr WSink epr) +       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) +       (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) +    (ECase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") +       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STEither t1 t2) (SpLeft s) epr e = +  let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") +  in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STEither t1 t2) (SpRight s) epr e = +  let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) +  in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = +  ELCase ext e +    (EZero ext (d2M (STEither t1 t2)) (ENil ext)) +    (ELCase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") +       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) +       (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) +    (ELCase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") +       (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") +       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLeft s) epr e = +  let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") +  in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STLEither t1 t2) (SpRight s) epr e = +  let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) +  in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STMaybe t) (SpMaybe s) epr e = +  EMaybe ext +    (ENothing ext (d2 t)) +    (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr +     in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) +    e +expandSparse (STArr _ t) (SpArr s) epr e = +  ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal sty) _ _ _ = case sty of {}  -- SpDense and SpSparse handled already +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +sparsePlus +  :: SMTy t -> Sparse t t1 -> Sparse t t2 +  -> (forall t3. Sparse t t3 +              -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) +              -> r) +  -> r +sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus  subenvPlus :: SList STy env -           -> Subenv env env1 -> Subenv env env2 -           -> (forall env3. Subenv env env3 -                         -> Subenv env3 env1 -                         -> Subenv env3 env2 -                         -> (Ex exenv (Tup (D2E env1)) -                             -> Ex exenv (Tup (D2E env2)) -                             -> Ex exenv (Tup (D2E env3))) +           -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2 +           -> (forall env3. SubenvS (D2E env) env3 +                         -> SubenvS env3 env1 +                         -> SubenvS env3 env2 +                         -> (Ex exenv (Tup env1) +                             -> Ex exenv (Tup env2) +                             -> Ex exenv (Tup env3))                           -> r)             -> r  subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)  subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =    subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->      k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = +subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =    subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> +    k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 ->        ELet ext e1 $          EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))                        (weakenExpr WSink e2))                    (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = +subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k =    subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> +    k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 ->        ELet ext e2 $          EPair ext (pl (weakenExpr WSink e1)                        (EFst ext (EVar ext (typeOf e2) IZ)))                    (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = +subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =    subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> +    k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 ->        ELet ext e1 $        ELet ext (weakenExpr WSink e2) $          EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) @@ -363,22 +457,44 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =                      (ESnd ext (EVar ext (typeOf e1) (IS IZ)))                      (ESnd ext (EVar ext (typeOf e2) IZ))) -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = -  ELet ext e $ -    let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ -    in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs +                  -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = +  eunPair e $ \w1 e1 e2 -> +    EPair ext +      (expandSubenvZeros (w1 .> WPop w) ts sub e1) +      (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = +  EPair ext +    (expandSubenvZeros (WPop w) ts sub e) +    (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))  assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]  assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl  assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" +assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty"  --------------------------------- ACCUMULATORS --------------------------------- +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = +  makeAccumulators (WPop w) envpro $ +    EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = +  ELet ext (uninvertTup list (STPair tcore t) e) $ +    let recT = STPair (STPair tcore t) (tTup list)  -- type of the RHS of that let binding +    in EPair ext +         (EFst ext (EFst ext (EVar ext recT IZ))) +         (EPair ext +            (ESnd ext (EVar ext recT IZ)) +            (ESnd ext (EFst ext (EVar ext recT IZ)))) +  fromArrayValId :: Maybe (ValId t) -> Maybe Int  fromArrayValId (Just (VIArr i _)) = Just i  fromArrayValId _ = Nothing @@ -422,7 +538,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        k (storepl `DPush` (t, vid, SAccum))          envpro          prosub -        (SEYes accrevsub) +        (SEYesR accrevsub)          (VarMap.sink1 accumMap)          (\shbinds ->            autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) @@ -449,7 +565,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->          k (storepl `DPush` (t, vid, SAccum))            (t `SCons` envpro) -          (SEYes prosub) +          (SEYesR prosub)            (SENo accrevsub)            (let accumMap' = VarMap.sink1 accumMap             in case fromArrayValId vid of @@ -499,19 +615,21 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of  ---------------------------- RETURN TRIPLE FROM CHAD ---------------------------  data Ret env0 sto t = -  forall shbinds tapebinds env0Merge. +  forall shbinds tapebinds contribs.      Ret (Bindings Ex (D1E env0) shbinds)  -- shared binds          (Subenv shbinds tapebinds)          (Ex (Append shbinds (D1E env0)) (D1 t)) -        (Subenv (Select env0 sto "merge") env0Merge) -        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) +        (SubenvS (D2E (Select env0 sto "merge")) contribs) +        (forall sd. Sparse (D2 t) sd +                 -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))  deriving instance Show (Ret env0 sto t)  data RetPair env0 sto env shbinds tapebinds t = -  forall env0Merge. +  forall contribs.      RetPair (Ex (Append shbinds env) (D1 t)) -            (Subenv (Select env0 sto "merge") env0Merge) -            (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) +            (SubenvS (D2E (Select env0 sto "merge")) contribs) +            (forall sd. Sparse (D2 t) sd +                     -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))  deriving instance Show (RetPair env0 sto env shbinds tapebinds t)  data Rets env0 sto env list = @@ -569,18 +687,24 @@ freezeRet :: Descr env sto  freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =    let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0        e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 +      tContribs = tTup (subList (d2e (select SMerge descr)) sub) +      library = #d (auto1 @(D2 t)) +                &. #tape (subList (bindingsBinds e0) subtape) +                &. #shbinds (bindingsBinds e0) +                &. #d2ace (d2ace (select SAccum descr)) +                &. #tl (desD1E descr) +                &. #contribs (SCons tContribs SNil)    in letBinds e0' $         EPair ext           (weakenExpr wInsertD2Ac e1) -         (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                          &. #tape (subList (bindingsBinds e0) subtape) -                                          &. #shbinds (bindingsBinds e0) -                                          &. #d2ace (d2ace (select SAccum descr)) -                                          &. #tl (desD1E descr)) +         (ELet ext (weakenExpr (autoWeak library                                           (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)                                           (#shbinds :++: #d :++: #d2ace :++: #tl))                        e2') $ -          expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) +          expandSubenvZeros +            (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) +             .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) +            (select SMerge descr) sub (EVar ext tContribs IZ))  ---------------------------- THE CHAD TRANSFORMATION --------------------------- @@ -596,21 +720,21 @@ drev des accumMap = \case          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvNone (select SMerge des)) +            (subenvNone (d2e (select SMerge des)))              (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))        Idx2Me tupI ->          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvOnehot (select SMerge des) tupI) +            (subenvOnehot (d2e (select SMerge des)) tupI)              (EPair ext (ENil ext) (EVar ext (d2 t) IZ))        Idx2Di _ ->          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvNone (select SMerge des)) +            (subenvNone (d2e (select SMerge des)))              (ENil ext)    ELet _ (rhs :: Expr _ _ a) body @@ -621,7 +745,7 @@ drev des accumMap = \case      , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)      , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->      subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> -    let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in +    let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in      Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')          (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)          (weakenExpr wbody0' body1) @@ -637,7 +761,7 @@ drev des accumMap = \case             (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $               weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $           plus_RHS_Body -           (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) +           (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ)             (EFst ext (EVar ext bodyResType (IS IZ))))    EPair _ a b @@ -649,16 +773,13 @@ drev des accumMap = \case          subtape          (EPair ext a1 b1)          subBoth -        (EMaybe ext -           (zeroTup (subList (select SMerge des) subBoth)) -           (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) -                      (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ -            ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) -                      (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ -            plus_A_B -             (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) -             (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) -           (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) +        (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) +                   (weakenExpr (WCopy WSink) a2)) $ +         ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) +                   (weakenExpr (WCopy (WSink .> WSink)) b2)) $ +         plus_A_B +           (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ)) +           (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ))    EFst _ e      | Ret e0 subtape e1 sub e2 <- drev des accumMap e @@ -732,7 +853,7 @@ drev des accumMap = \case              ECase ext e1                (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))                (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) -        (SEYes subtapeE) +        (SEYesR subtapeE)          (EFst ext (EVar ext tPrimal IZ))          subOut          (ELet ext @@ -801,7 +922,7 @@ drev des accumMap = \case                 (weakenExpr (WCopy WSink) e2))        Nonlinear d2opfun ->          Ret (e0 `BPush` (d1 (typeOf e), e1)) -            (SEYes subtape) +            (SEYesR subtape)              (d1op op $ EVar ext (d1 (typeOf e)) IZ)              sub              (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) @@ -816,7 +937,7 @@ drev des accumMap = \case                 `BPush` (typeOf b1, weakenExpr WSink b1)                 `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))                 `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) -        (SEYes (SENo (SENo (SENo subtape)))) +        (SEYesR (SENo (SENo (SENo subtape))))          (EFst ext (EVar ext (typeOf pr) (IS IZ)))          bsub          (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ @@ -865,7 +986,7 @@ drev des accumMap = \case      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim ->      deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> -    let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in +    let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in      subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->      accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->      let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in @@ -894,7 +1015,7 @@ drev des accumMap = \case                              in EPair ext (weakenExpr w e1) (collectexpr w)))                `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))                                                 (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) -        (SEYes (SENo (SEYes SETop))) +        (SEYesR (SENo (SEYesR SETop)))          (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))                (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))          (subenvCompose subMergeUsed proSub) @@ -981,7 +1102,7 @@ drev des accumMap = \case      , STArr (SS n) eltty <- typeOf e ->      Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)                 `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) -        (SEYes (SENo subtape)) +        (SEYesR (SENo subtape))          (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))                     (weakenExpr (WSink .> WSink) ei1))          sub @@ -1002,7 +1123,7 @@ drev des accumMap = \case      Ret (binds `BPush` (STArr n (d1 eltty), e1)                 `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))                 `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) -        (SEYes (SEYes (SENo subtape))) +        (SEYesR (SEYesR (SENo subtape)))          (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))                    (EVar ext (tTup (sreplicate n tIx)) IZ))          sub @@ -1030,7 +1151,7 @@ drev des accumMap = \case      , STArr (SS n) t <- typeOf e ->      Ret (e0 `BPush` (STArr (SS n) t, e1)              `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) -        (SEYes (SENo subtape)) +        (SEYesR (SENo subtape))          (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))          sub          (EMaybe ext @@ -1076,7 +1197,7 @@ drev des accumMap = \case        , let tIxN = tTup (sreplicate (SS n) tIx) =        Ret (e0 `BPush` (at, e1)                `BPush` (at', extremum (EVar ext at IZ))) -          (SEYes (SEYes subtape)) +          (SEYesR (SEYesR subtape))            (EVar ext at' IZ)            sub            (EMaybe ext @@ -1094,16 +1215,17 @@ drev des accumMap = \case  data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)  data RetScoped env0 sto a s t = -  forall shbinds tapebinds env0Merge. +  forall shbinds tapebinds contribs.      RetScoped          (Bindings Ex (D1E (a : env0)) shbinds)  -- shared binds          (Subenv shbinds tapebinds)          (Ex (Append shbinds (D1E (a : env0))) (D1 t)) -        (Subenv (Select env0 sto "merge") env0Merge) +        (SubenvS (D2E (Select env0 sto "merge")) contribs)             -- ^ merge contributions to the _enclosing_ merge environment -        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) -           (If (s == "discr") (Tup (D2E env0Merge)) -                              (TPair (Tup (D2E env0Merge)) (D2 a)))) +        (forall sd. Sparse (D2 t) sd +                 -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) +                       (If (s == "discr") (Tup contribs) +                                          (TPair (Tup contribs) (D2 a))))            -- ^ the merge contributions, plus the cotangent to the argument            -- (if there is any)  deriving instance Show (RetScoped env0 sto a s t) @@ -1118,7 +1240,7 @@ drevScoped des accumMap argty argsto argids expr = case argsto of    SMerge      | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->          case sub of -          SEYes sub' -> RetScoped e0 subtape e1 sub' e2 +          SEYesR sub' -> RetScoped e0 subtape e1 sub' e2            SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))    SAccum diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index d8a71b5..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,27 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -module CHAD.Accum where - -import AST -import CHAD.Types -import Data - - - -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = -  makeAccumulators envpro $ -    EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = -  ELet ext (uninvertTup list (STPair tcore t) e) $ -    let recT = STPair (STPair tcore t) (tTup list)  -- type of the RHS of that let binding -    in EPair ext -         (EFst ext (EFst ext (EVar ext recT IZ))) -         (EPair ext -            (ESnd ext (EVar ext recT IZ)) -            (ESnd ext (EFst ext (EVar ext recT IZ)))) - diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index 4c287d7..49ae0e6 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env'                         -> r)           -> r  subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =    subDescr des sub $ \des' submerge subaccum subd1e ->      case sto of -      SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) -      SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) -      SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) +      SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) +      SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) +      SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)  subDescr (des `DPush` (_, _, sto)) (SENo sub) k =    subDescr des sub $ \des' submerge subaccum subd1e ->      case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des  select s@SAccum (DPush des (_, _, SDiscr)) = select s des  select s@SMerge (DPush des (_, _, SDiscr)) = select s des  select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 974669d..83f013d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -3,6 +3,7 @@  {-# LANGUAGE TypeOperators #-}  module CHAD.Types where +import AST.Accum  import AST.Types  import Data @@ -18,11 +19,11 @@ type family D1 t where  type family D2 t where    D2 TNil = TNil -  D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) +  D2 (TPair a b) = TPair (D2 a) (D2 b)    D2 (TEither a b) = TLEither (D2 a) (D2 b)    D2 (TLEither a b) = TLEither (D2 a) (D2 b)    D2 (TMaybe t) = TMaybe (D2 t) -  D2 (TArr n t) = TMaybe (TArr n (D2 t)) +  D2 (TArr n t) = TArr n (D2 t)    D2 (TScal t) = D2s t  type family D2s t where @@ -60,11 +61,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env  d2M :: STy t -> SMTy (D2 t)  d2M STNil = SMTNil -d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STPair a b) = SMTPair (d2M a) (d2M b)  d2M (STEither a b) = SMTLEither (d2M a) (d2M b)  d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)  d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STArr n t) = SMTArr n (d2M t)  d2M (STScal t) = case t of    STI32 -> SMTNil    STI64 -> SMTNil @@ -116,3 +117,10 @@ chcSetAccum c = c { chcLetArrayAccum = True  indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))  indexTupD1Id SZ = Refl  indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs index 9c10421..2712b08 100644 --- a/src/Data/VarMap.hs +++ b/src/Data/VarMap.hs @@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'  subMap subenv =    let bools = let loop :: Subenv env env' -> [Bool]                    loop SETop = [] -                  loop (SEYes sub) = True : loop sub +                  loop (SEYesR sub) = True : loop sub                    loop (SENo sub) = False : loop sub                in VS.fromList $ loop subenv        newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env  superMap subenv =    let loop :: Subenv env env' -> Int -> [Int]        loop SETop _ = [] -      loop (SEYes sub) i = i : loop sub (i+1) +      loop (SEYesR sub) i = i : loop sub (i+1)        loop (SENo sub) i = loop sub (i+1)        newIndices = VS.fromList $ loop subenv 0 | 
