From 56056c98b2e3dce65a0e42bce0410c083fd1f8be Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 6 Jun 2025 22:50:06 +0200 Subject: WIP mixed static/dynamic sparsity --- chad-fast.cabal | 2 +- src/AST.hs | 33 +++- src/AST/Accum.hs | 17 -- src/AST/Bindings.hs | 2 +- src/AST/Count.hs | 6 +- src/AST/Env.hs | 58 ++++--- src/AST/Sparse.hs | 434 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/AST/Types.hs | 2 +- src/CHAD.hs | 298 ++++++++++++++++++++++++----------- src/CHAD/Accum.hs | 27 ---- src/CHAD/EnvDescr.hs | 20 ++- src/CHAD/Types.hs | 16 +- src/Data/VarMap.hs | 4 +- 13 files changed, 747 insertions(+), 172 deletions(-) create mode 100644 src/AST/Sparse.hs delete mode 100644 src/CHAD/Accum.hs diff --git a/chad-fast.cabal b/chad-fast.cabal index b0ed639..b8510d2 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -18,13 +18,13 @@ library AST.Count AST.Env AST.Pretty + AST.Sparse AST.SplitLets AST.Types AST.UnMonoid AST.Weaken AST.Weaken.Auto CHAD - CHAD.Accum CHAD.EnvDescr CHAD.Top CHAD.Types diff --git a/src/AST.hs b/src/AST.hs index 149cddd..0000836 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -503,11 +503,40 @@ eshapeEmpty (SS n) e = (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) -ezeroD2 :: STy t -> Ex env (D2 t) -ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi -- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil -- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea -- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) -- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = ELet ext rhs body + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 03369c8..1101cc0 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where import AST.Types -import CHAD.Types import Data @@ -75,20 +72,6 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" - -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 745a93b..2310f4b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -69,7 +69,7 @@ collectBindings = \env -> fst . go env WId where go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) go _ _ SETop = (BTop, WId) - go (ty `SCons` env) w (SEYes sub) = + go (ty `SCons` env) w (SEYesR sub) = let (bs, w') = go env (WPop w) sub in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 0c682c6..03a36f6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -154,7 +154,7 @@ deleteUnused (_ `SCons` env) OccEnd k = deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = deleteUnused env occenv $ \sub -> case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) + _ -> k (SEYesR sub) unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t unsafeWeakenWithSubenv = \sub -> @@ -163,7 +163,7 @@ unsafeWeakenWithSubenv = \sub -> Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") where sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ + sinkViaSubenv IZ (SEYesR _) = Just IZ sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..bc2b9e0 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,73 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where +import Data.Type.Equality + +import AST.Sparse import AST.Weaken import Data -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} -subList :: SList f env -> Subenv env env' -> SList f env' +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: IsSubType s => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons _ env) = SEYes subtFull (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t] +subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env) subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) subenvOnehot SNil i = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..09dbc70 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,434 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-} +module AST.Sparse where + +import Data.Kind (Constraint, Type) +import Data.Type.Equality + +import AST + + +data Sparse t t' where + SpDense :: Sparse t t + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpLeft :: Sparse a a' -> Sparse (TLEither a b) a' + SpRight :: Sparse b b' -> Sparse (TLEither a b) b' + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpJust :: Sparse t t' -> Sparse (TMaybe t) t' + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') +deriving instance Show (Sparse t t') + +applySparse :: Sparse t t' -> STy t -> STy t' +applySparse SpDense t = t +applySparse (SpSparse s) t = STMaybe (applySparse s t) +applySparse SpAbsent _ = STNil +applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) +applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1 +applySparse (SpRight s) (STLEither _ t2) = applySparse s t2 +applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) +applySparse (SpJust s) (STMaybe t) = applySparse s t +applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + + +class IsSubType s where + type IsSubTypeSubject s (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: s a a + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ STy + subtApply = applySparse + + subtTrans SpDense s = s + subtTrans s SpDense = s + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2) + subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2) + subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2) + subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2 + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2) + subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + + subtFull = SpDense + + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This eliminates pointless checks. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. しょうがない。 +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override +sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) +sparsePlusS ST _ t SpAbsent sp2 k = + k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) +sparsePlusS _ ST t sp1 SpAbsent k = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) +sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- coproducts with partially known arguments: if we have a non-nil +-- always-present coproduct argument, the result is dense, otherwise we +-- introduce sparsity +sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = + sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa -> + k (SpLeft sp3a) + (Inj inj13a) + Noinj + (\x1 x2 -> + elet x1 $ + elcase (weakenExpr WSink x2) + (inj13a (evar IZ)) + (plusa (evar (IS IZ)) (evar IZ)) + (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = + sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa -> + k (SpSparse (SpLeft sp3a)) + (Inj $ \x1 -> EJust ext (inj13a x1)) + (Inj $ \x2 -> + elcase x2 + (ENothing ext (applySparse sp3a (fromSMTy ta))) + (EJust ext (inj23a (evar IZ))) + (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr")) + (\x1 x2 -> + elet x1 $ + EJust ext $ + elcase (weakenExpr WSink x2) + (inj13a (evar IZ)) + (plusa (evar (IS IZ)) (evar IZ)) + (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa) +sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = + sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb -> + k (SpRight sp3b) + (Inj inj13b) + Noinj + (\x1 x2 -> + elet x1 $ + elcase (weakenExpr WSink x2) + (inj13b (evar IZ)) + (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") + (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = + sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb -> + k (SpSparse (SpRight sp3b)) + (Inj $ \x1 -> EJust ext (inj13b x1)) + (Inj $ \x2 -> + elcase x2 + (ENothing ext (applySparse sp3b (fromSMTy tb))) + (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll") + (EJust ext (inj23b (evar IZ)))) + (\x1 x2 -> + elet x1 $ + EJust ext $ + elcase (weakenExpr WSink x2) + (inj13b (evar IZ)) + (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") + (plusb (evar (IS IZ)) (evar IZ))) + +sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb) +sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k + +-- dense same-branch coproducts simply recurse +sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k = + sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpLeft sp3) inj1 inj2 plus +sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k = + sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpRight sp3) inj1 inj2 plus + +-- dense, mismatched coproducts are valid as long as we don't actually invoke +-- plus at runtime (injections are fine) +sparsePlusS SF SF _ SpLeft{} SpRight{} k = + k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr") +sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k = + k (SpRight sp2) Noinj (Inj id) + (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr") +sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k = + k (SpLeft sp1) (Inj id) Noinj + (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll") +sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k = + -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it. + k (SpLEither sp1 sp2) + (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a) + (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b) + (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr") + +sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh + sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) +sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- maybe with partially known arguments: if we have an always-present Just +-- argument, the result is dense, otherwise we introduce sparsity by weakening +-- to SpMaybe +sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = + sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus -> + k (SpJust sp3) + (Inj inj1) + Noinj + (\a b -> + elet a $ + emaybe (weakenExpr WSink b) + (inj1 (evar IZ)) + (plus (evar (IS IZ)) (evar IZ))) +sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> EJust ext (inj1 a)) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet a $ + emaybe (weakenExpr WSink b) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ)))) + +sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k = + sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus) +sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k +sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k + +-- dense same-branch maybes simply recurse +sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus -> + k (SpJust sp3) inj1 inj2 plus + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) +sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k +sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k diff --git a/src/AST/Types.hs b/src/AST/Types.hs index a3b7302..42bfb92 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -5,9 +5,9 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-} module AST.Types where import Data.Int (Int32, Int64) diff --git a/src/CHAD.hs b/src/CHAD.hs index df792ce..b5a9af0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -42,8 +42,8 @@ import AST import AST.Bindings import AST.Count import AST.Env +import AST.Sparse import AST.Weaken.Auto -import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -65,7 +65,7 @@ tapeTy (SCons t ts) = STPair t (tapeTy ts) bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w = EPair ext (EVar ext t (w @> IZ)) (bindingsCollectTape binds sub (w .> WSink)) bindingsCollectTape (BPush binds _) (SENo sub) w = @@ -227,26 +227,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) -> D2Op (TScal a) t @@ -261,11 +272,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +304,7 @@ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -317,44 +328,127 @@ conv2Idx DTop i = case i of {} ------------------------------------ MONOIDS ----------------------------------- -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) - - ------------------------------------- SUBENVS ----------------------------------- +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) +zeroTup SNil _ = ENil ext +zeroTup (t `SCons` env) w = + EPair ext (zeroTup env (WPop w)) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +----------------------------------- SPARSITY ----------------------------------- + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse _ SpDense _ e = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STEither t1 t2) (SpLeft s) epr e = + let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") + in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STEither t1 t2) (SpRight s) epr e = + let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) + in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLeft s) epr e = + let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") + in ELInl ext (d2 t2) (expandSparse t1 s epr' e) +expandSparse (STLEither t1 t2) (SpRight s) epr e = + let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) + in ELInr ext (d2 t1) (expandSparse t2 s epr' e) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +sparsePlus + :: SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) + -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2 + -> (forall env3. SubenvS (D2E env) env3 + -> SubenvS env3 env1 + -> SubenvS env3 env2 + -> (Ex exenv (Tup env1) + -> Ex exenv (Tup env2) + -> Ex exenv (Tup env3)) -> r) -> r subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = +subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> + k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 -> ELet ext e1 $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = +subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> + k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 -> ELet ext e2 $ EPair ext (pl (weakenExpr WSink e1) (EFst ext (EVar ext (typeOf e2) IZ))) (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = +subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> + k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 -> ELet ext e1 $ ELet ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) @@ -363,22 +457,44 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ))) -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" +assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing @@ -422,7 +538,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of k (storepl `DPush` (t, vid, SAccum)) envpro prosub - (SEYes accrevsub) + (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) @@ -449,7 +565,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) (t `SCons` envpro) - (SEYes prosub) + (SEYesR prosub) (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of @@ -499,19 +615,21 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- data Ret env0 sto t = - forall shbinds tapebinds env0Merge. + forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (forall sd. Sparse (D2 t) sd + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) deriving instance Show (Ret env0 sto t) data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. + forall contribs. RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (forall sd. Sparse (D2 t) sd + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) deriving instance Show (RetPair env0 sto env shbinds tapebinds t) data Rets env0 sto env list = @@ -569,18 +687,24 @@ freezeRet :: Descr env sto freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (subList (d2e (select SMerge descr)) sub) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) + (ELet ext (weakenExpr (autoWeak library (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) ---------------------------- THE CHAD TRANSFORMATION --------------------------- @@ -596,21 +720,21 @@ drev des accumMap = \case Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) + (subenvOnehot (d2e (select SMerge des)) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) Idx2Di _ -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) ELet _ (rhs :: Expr _ _ a) body @@ -621,7 +745,7 @@ drev des accumMap = \case , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) (weakenExpr wbody0' body1) @@ -637,7 +761,7 @@ drev des accumMap = \case (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b @@ -649,16 +773,13 @@ drev des accumMap = \case subtape (EPair ext a1 b1) subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ)) + (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ)) EFst _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap e @@ -732,7 +853,7 @@ drev des accumMap = \case ECase ext e1 (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) + (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut (ELet ext @@ -801,7 +922,7 @@ drev des accumMap = \case (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) + (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) @@ -816,7 +937,7 @@ drev des accumMap = \case `BPush` (typeOf b1, weakenExpr WSink b1) `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) + (SEYesR (SENo (SENo (SENo subtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ @@ -865,7 +986,7 @@ drev des accumMap = \case , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in @@ -894,7 +1015,7 @@ drev des accumMap = \case in EPair ext (weakenExpr w e1) (collectexpr w))) `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) + (SEYesR (SENo (SEYesR SETop))) (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvCompose subMergeUsed proSub) @@ -981,7 +1102,7 @@ drev des accumMap = \case , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) sub @@ -1002,7 +1123,7 @@ drev des accumMap = \case Ret (binds `BPush` (STArr n (d1 eltty), e1) `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) + (SEYesR (SEYesR (SENo subtape))) (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub @@ -1030,7 +1151,7 @@ drev des accumMap = \case , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub (EMaybe ext @@ -1076,7 +1197,7 @@ drev des accumMap = \case , let tIxN = tTup (sreplicate (SS n) tIx) = Ret (e0 `BPush` (at, e1) `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) + (SEYesR (SEYesR subtape)) (EVar ext at' IZ) sub (EMaybe ext @@ -1094,16 +1215,17 @@ drev des accumMap = \case data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. + forall shbinds tapebinds contribs. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) + (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) + (forall sd. Sparse (D2 t) sd + -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) (D2 a)))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) deriving instance Show (RetScoped env0 sto a s t) @@ -1118,7 +1240,7 @@ drevScoped des accumMap argty argsto argids expr = case argsto of SMerge | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 + SEYesR sub' -> RetScoped e0 subtape e1 sub' e2 SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) SAccum diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index d8a71b5..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,27 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -module CHAD.Accum where - -import AST -import CHAD.Types -import Data - - - -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = - makeAccumulators envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index 4c287d7..49ae0e6 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des select s@SAccum (DPush des (_, _, SDiscr)) = select s des select s@SMerge (DPush des (_, _, SDiscr)) = select s des select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 974669d..83f013d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -3,6 +3,7 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Types where +import AST.Accum import AST.Types import Data @@ -18,11 +19,11 @@ type family D1 t where type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) D2 (TEither a b) = TLEither (D2 a) (D2 b) D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -60,11 +61,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env d2M :: STy t -> SMTy (D2 t) d2M STNil = SMTNil -d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STPair a b) = SMTPair (d2M a) (d2M b) d2M (STEither a b) = SMTLEither (d2M a) (d2M b) d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STArr n t) = SMTArr n (d2M t) d2M (STScal t) = case t of STI32 -> SMTNil STI64 -> SMTNil @@ -116,3 +117,10 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs index 9c10421..2712b08 100644 --- a/src/Data/VarMap.hs +++ b/src/Data/VarMap.hs @@ -74,7 +74,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' subMap subenv = let bools = let loop :: Subenv env env' -> [Bool] loop SETop = [] - loop (SEYes sub) = True : loop sub + loop (SEYesR sub) = True : loop sub loop (SENo sub) = False : loop sub in VS.fromList $ loop subenv newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +89,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env superMap subenv = let loop :: Subenv env env' -> Int -> [Int] loop SETop _ = [] - loop (SEYes sub) i = i : loop sub (i+1) + loop (SEYesR sub) i = i : loop sub (i+1) loop (SENo sub) i = loop sub (i+1) newIndices = VS.fromList $ loop subenv 0 -- cgit v1.2.3-70-g09d2 From 514c4bb0bfe908ec39ab4fa09dbf51bf7db29bd4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 8 Jun 2025 22:07:17 +0200 Subject: More WIP sparsity --- src/AST/Env.hs | 24 +- src/AST/Sparse.hs | 244 +++++------------ src/AST/Weaken/Auto.hs | 2 +- src/CHAD.hs | 718 +++++++++++++++++++++++++++++-------------------- 4 files changed, 516 insertions(+), 472 deletions(-) diff --git a/src/AST/Env.hs b/src/AST/Env.hs index bc2b9e0..422f0f7 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where @@ -12,6 +13,7 @@ import Data.Type.Equality import AST.Sparse import AST.Weaken +import CHAD.Types import Data @@ -38,18 +40,18 @@ subList SNil SETop = SNil subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: IsSubType s => SList f env -> Subenv' s env env +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes subtFull (subenvAll env) +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t] -subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop @@ -71,3 +73,13 @@ wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 09dbc70..ddae7fe 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -7,7 +7,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-} +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module AST.Sparse where import Data.Kind (Constraint, Type) @@ -17,66 +17,99 @@ import AST data Sparse t t' where - SpDense :: Sparse t t SpSparse :: Sparse t t' -> Sparse t (TMaybe t') SpAbsent :: Sparse t TNil SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpLeft :: Sparse a a' -> Sparse (TLEither a b) a' - SpRight :: Sparse b b' -> Sparse (TLEither a b) b' SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpJust :: Sparse t t' -> Sparse (TMaybe t) t' SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) deriving instance Show (Sparse t t') -applySparse :: Sparse t t' -> STy t -> STy t' -applySparse SpDense t = t -applySparse (SpSparse s) t = STMaybe (applySparse s t) -applySparse SpAbsent _ = STNil -applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) -applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) -applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1 -applySparse (SpRight s) (STLEither _ t2) = applySparse s t2 -applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) -applySparse (SpJust s) (STMaybe t) = applySparse s t -applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t class IsSubType s where - type IsSubTypeSubject s (f :: k -> Type) :: Constraint + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' subtTrans :: s a b -> s b c -> s a c - subtFull :: s a a + subtFull :: IsSubTypeSubject s f => f t -> s t t instance IsSubType (:~:) where type IsSubTypeSubject (:~:) f = () subtApply = gcastWith subtTrans = trans - subtFull = Refl + subtFull _ = Refl instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ STy + type IsSubTypeSubject Sparse f = f ~ SMTy subtApply = applySparse - subtTrans SpDense s = s - subtTrans s SpDense = s subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) subtTrans _ SpAbsent = SpAbsent subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2) - subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2) - subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2) - subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2) subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2 subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2) - subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2) subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - - subtFull = SpDense + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False data SBool b where @@ -176,7 +209,10 @@ sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t -- TODO: sparse of Just is just Maybe -- dense plus -sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b) +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) -- handle absents sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) @@ -239,8 +275,6 @@ sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = eunPair x1 $ \w1 x1a x1b -> eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) -sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k -- coproducts sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = @@ -268,107 +302,6 @@ sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k (inr (inj13b (evar IZ))) (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") (inr (plusb (evar (IS IZ)) (evar IZ))))) -sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - --- coproducts with partially known arguments: if we have a non-nil --- always-present coproduct argument, the result is dense, otherwise we --- introduce sparsity -sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = - sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa -> - k (SpLeft sp3a) - (Inj inj13a) - Noinj - (\x1 x2 -> - elet x1 $ - elcase (weakenExpr WSink x2) - (inj13a (evar IZ)) - (plusa (evar (IS IZ)) (evar IZ)) - (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) - -sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = - sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa -> - k (SpSparse (SpLeft sp3a)) - (Inj $ \x1 -> EJust ext (inj13a x1)) - (Inj $ \x2 -> - elcase x2 - (ENothing ext (applySparse sp3a (fromSMTy ta))) - (EJust ext (inj23a (evar IZ))) - (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr")) - (\x1 x2 -> - elet x1 $ - EJust ext $ - elcase (weakenExpr WSink x2) - (inj13a (evar IZ)) - (plusa (evar (IS IZ)) (evar IZ)) - (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) - -sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa) -sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - -sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = - sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb -> - k (SpRight sp3b) - (Inj inj13b) - Noinj - (\x1 x2 -> - elet x1 $ - elcase (weakenExpr WSink x2) - (inj13b (evar IZ)) - (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") - (plusb (evar (IS IZ)) (evar IZ))) - -sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = - sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb -> - k (SpSparse (SpRight sp3b)) - (Inj $ \x1 -> EJust ext (inj13b x1)) - (Inj $ \x2 -> - elcase x2 - (ENothing ext (applySparse sp3b (fromSMTy tb))) - (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll") - (EJust ext (inj23b (evar IZ)))) - (\x1 x2 -> - elet x1 $ - EJust ext $ - elcase (weakenExpr WSink x2) - (inj13b (evar IZ)) - (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") - (plusb (evar (IS IZ)) (evar IZ))) - -sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb) -sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - --- dense same-branch coproducts simply recurse -sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k = - sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpLeft sp3) inj1 inj2 plus -sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k = - sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpRight sp3) inj1 inj2 plus - --- dense, mismatched coproducts are valid as long as we don't actually invoke --- plus at runtime (injections are fine) -sparsePlusS SF SF _ SpLeft{} SpRight{} k = - k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr") -sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k = - k (SpRight sp2) Noinj (Inj id) - (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr") -sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k = - k (SpLeft sp1) (Inj id) Noinj - (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll") -sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k = - -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it. - k (SpLEither sp1 sp2) - (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a) - (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b) - (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr") - -sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh - sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) -- maybe sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = @@ -385,42 +318,6 @@ sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = (emaybe (evar (IS IZ)) (EJust ext (inj1 (evar IZ))) (EJust ext (plus (evar (IS IZ)) (evar IZ))))) -sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k - --- maybe with partially known arguments: if we have an always-present Just --- argument, the result is dense, otherwise we introduce sparsity by weakening --- to SpMaybe -sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = - sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus -> - k (SpJust sp3) - (Inj inj1) - Noinj - (\a b -> - elet a $ - emaybe (weakenExpr WSink b) - (inj1 (evar IZ)) - (plus (evar (IS IZ)) (evar IZ))) -sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpMaybe sp3) - (Inj $ \a -> EJust ext (inj1 a)) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet a $ - emaybe (weakenExpr WSink b) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ)))) - -sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus) -sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k - --- dense same-branch maybes simply recurse -sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k = - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpJust sp3) inj1 inj2 plus -- dense array cotangents simply recurse sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = @@ -430,5 +327,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) -sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k -sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 6752c24..c6efe37 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list diff --git a/src/CHAD.hs b/src/CHAD.hs index b5a9af0..241825e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -11,6 +11,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -62,15 +63,21 @@ tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w = +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = EPair ext (EVar ext t (w @> IZ)) (bindingsCollectTape binds sub (w .> WSink)) -bindingsCollectTape (BPush binds _) (SENo sub) w = +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = bindingsCollectTape binds sub (w .> WSink) +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + -- In order from large to small: i.e. in reverse order from what we want, -- because in a Bindings, the head of the list is the bottom-most entry. type family TapeUnfoldings binds where @@ -325,6 +332,21 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) = Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + ------------------------------------ MONOIDS ----------------------------------- @@ -355,7 +377,7 @@ subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) subenvD1E (SENo sub) = SENo (subenvD1E sub) expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) -expandSparse _ SpDense _ e = e +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e expandSparse t (SpSparse sp) epr e = EMaybe ext (EZero ext (d2M t) (d2zeroInfo t epr)) @@ -376,12 +398,6 @@ expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = (ECase ext (weakenExpr WSink epr) (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) -expandSparse (STEither t1 t2) (SpLeft s) epr e = - let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") - in ELInl ext (d2 t2) (expandSparse t1 s epr' e) -expandSparse (STEither t1 t2) (SpRight s) epr e = - let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) - in ELInr ext (d2 t1) (expandSparse t2 s epr' e) expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = ELCase ext e (EZero ext (d2M (STEither t1 t2)) (ENil ext)) @@ -393,12 +409,6 @@ expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) -expandSparse (STLEither t1 t2) (SpLeft s) epr e = - let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL") - in ELInl ext (d2 t2) (expandSparse t1 s epr' e) -expandSparse (STLEither t1 t2) (SpRight s) epr e = - let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ) - in ELInr ext (d2 t1) (expandSparse t2 s epr' e) expandSparse (STMaybe t) (SpMaybe s) epr e = EMaybe ext (ENothing ext (d2 t)) @@ -407,55 +417,72 @@ expandSparse (STMaybe t) (SpMaybe s) epr e = e expandSparse (STArr _ t) (SpArr s) epr e = ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e -expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" -sparsePlus - :: SMTy t -> Sparse t t1 -> Sparse t t2 - -> (forall t3. Sparse t t3 - -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) - -> r) - -> r -sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus - -subenvPlus :: SList STy env - -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2 - -> (forall env3. SubenvS (D2E env) env3 - -> SubenvS env3 env1 - -> SubenvS env3 env2 - -> (Ex exenv (Tup env1) - -> Ex exenv (Tup env2) - -> Ex exenv (Tup env3)) +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext (d2M t) - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) @@ -470,10 +497,10 @@ expandSubenvZeros w (SCons t ts) (SENo sub) e = (expandSubenvZeros (WPop w) ts sub e) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] +assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty" +assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- @@ -523,8 +550,8 @@ accumPromote :: forall dt env sto proxy r. -- accumulators. -> (forall shbinds. SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. @@ -541,11 +568,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -582,7 +609,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -614,23 +641,41 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- -data Ret env0 sto t = +data Ret env0 sto sd t = forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) (SubenvS (D2E (Select env0 sto "merge")) contribs) - (forall sd. Sparse (D2 t) sd - -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) -deriving instance Show (Ret env0 sto t) - -data RetPair env0 sto env shbinds tapebinds t = - forall contribs. - RetPair (Ex (Append shbinds env) (D1 t)) - (SubenvS (D2E (Select env0 sto "merge")) contribs) - (forall sd. Sparse (D2 t) sd - -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) data Rets env0 sto env list = forall shbinds tapebinds. @@ -639,8 +684,11 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env shbinds tapebinds) list) deriving instance Show (Rets env0 sto env list) +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list @@ -648,46 +696,47 @@ weakenRets w (Rets binds tapesub list) = let (binds', _) = weakenBindings weakenExpr w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. Descr env0 sto -> SList f b1 -> SList f b2 -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) + <- weakenRets (sinkWithBindings e0) (retConcat descr list) , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) + = Rets (bconcat e0 binds) (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) subtape subtape2) pairs)) freezeRet :: Descr env sto - -> Ret env sto t + -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 - tContribs = tTup (subList (d2e (select SMerge descr)) sub) + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) library = #d (auto1 @(D2 t)) &. #tape (subList (bindingsBinds e0) subtape) &. #shbinds (bindingsBinds e0) @@ -709,11 +758,34 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = ---------------------------- THE CHAD TRANSFORMATION --------------------------- -drev :: forall env sto t. +drev :: forall env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (evar IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> @@ -721,14 +793,15 @@ drev des accumMap = \case SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) - (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (d2e (select SMerge des)) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) Idx2Di _ -> Ret BTop @@ -738,20 +811,22 @@ drev des accumMap = \case (ENil ext) ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) + (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) @@ -761,14 +836,15 @@ drev des accumMap = \case (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body - (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ) + (EVar ext (contribTupTy des subRHS) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> Ret binds subtape (EPair ext a1 b1) @@ -778,147 +854,155 @@ drev des accumMap = \case ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (WSink .> WSink)) b2)) $ plus_A_B - (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ)) - (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ)) + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> Ret e0 subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> Ret e0 subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInl ext (d1 t2) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) - (zeroTup (subList (select SMerge des) sub)) - (weakenExpr (WCopy WSink) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInr ext (d1 t1) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) - (zeroTup (subList (select SMerge des) sub)) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy WSink) e2)) + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollectTape a0 subtapeA - , let collectB = bindingsCollectTape b0 subtapeB + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> Ret (e0 `BPush` (tPrimal, ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))) (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut - (ELet ext + (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ + (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ in letBinds rebinds $ ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #ta0 :++: #tl) (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (ELInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ in letBinds rebinds $ ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #tb0 :++: #tl) (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (ELInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) EConst _ t val -> Ret BTop SETop (EConst ext t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> case d2op op of Linear d2opfun -> Ret e0 subtape (d1op op e1) sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) @@ -926,36 +1010,51 @@ drev des accumMap = \case (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - ECustom _ _ _ storety _ pr du a b + ECustom _ _ tb storety srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYesR (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) + `BPush` (typeOf b1, weakenExpr WSink b1) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) + `BPush` (typeOf b1, weakenExpr WSink b1) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + ERecompute _ e -> deleteUnused (descrList des) (occCountAll e) $ \usedSub -> let smallE = unsafeWeakenWithSubenv usedSub e in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in Ret (collectBindings (desD1E des) subD1eUsed) (subenvAll (desD1E usedDes)) - (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) - (subenvCompose subMergeUsed sub) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr - (autoWeak (#d (auto1 @(D2 t)) + (autoWeak (#d (auto1 @sd) &. #shbinds (bindingsBinds e0) &. #tape (subList (bindingsBinds e0) subtape) &. #d1env (desD1E usedDes) @@ -970,31 +1069,32 @@ drev des accumMap = \case Ret BTop SETop (EError ext (d1 t) s) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EConstArr _ n t val -> Ret BTop SETop (EConstArr ext n t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result + | SpArr @_ @sdElt sdElt <- sd , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> + case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) + let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in + Ret (BTop `BPush` (shty, drevPrimal des she) `BPush` (STArr ndim (STPair (d1 eltty) tapety) ,EBuild ext ndim (EVar ext shty IZ) @@ -1012,58 +1112,59 @@ drev des accumMap = \case &. #d1env' (desD1E usedDes)) (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w'))) `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) (SEYesR (SENo (SEYesR SETop))) (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) - }} + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in + ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + -- TODO: what's happening here is that because of the sparsity + -- rewrite, makeAccumulators needs primals where it previously + -- didn't. The build derivative is currently not saving those + -- primals, so the hole below cannot currently be filled. The + -- appropriate primals (waves hands) need to be stored, so that a + -- weakening can be provided here. + makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#d (auto1 @sdElt) + &. #pro (d2ace envPro) + &. #etape (subList (bindingsBinds e0) subtapeE) + &. #prerebinds prerebinds + &. #tape (auto1 @(Tape e_tape)) + &. #ix (auto1 @shty) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) + &. #sh (auto1 @shty) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + .> wPro (subList (bindingsBinds e0) subtapeE)) + e2) + }}} EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> Ret e0 subtape (EUnit ext e1) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'ei' is discrete. @@ -1177,7 +1278,6 @@ drev des accumMap = \case ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum - EAccum{} -> err_accum EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid @@ -1189,7 +1289,8 @@ drev des accumMap = \case deriv_extremum :: ScalIsNumeric t' ~ True => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) + -> Sparse (TArr n (D2s t')) sd' + -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) deriv_extremum extremum e | Ret e0 subtape e1 sub e2 <- drev des accumMap e , at@(STArr (SS n) t@(STScal st)) <- typeOf e @@ -1212,70 +1313,103 @@ drev des accumMap = \case weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) (EVar ext (d2 at') IZ)) + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) -data RetScoped env0 sto a s t = - forall shbinds tapebinds contribs. +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) + (Subenv (Append shbinds '[D1 a]) tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment - (forall sd. Sparse (D2 t) sd - -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup contribs) - (TPair (Tup contribs) (D2 a)))) + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) +deriving instance Show (RetScoped env0 sto a s sd t) -drevScoped :: forall a s env sto t. +drevScoped :: forall a s env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> case sub of - SEYesR sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum | Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) + weakenExpr (autoWeak (#d (auto1 @sd) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (ezeroD2 argty)) + (EPair ext e2 (ENil ext)) | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext (d2M argty) (ezeroD2 argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) e2 SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- chadD1Id (typeOf e) + , Refl <- chadD1EId (descrList des) + = mapExt (const ext) e + where + chadD1Id :: STy a -> D1 a :~: a + chadD1Id STNil = Refl + chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl + chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl + chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl + chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl + chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl + chadD1Id (STScal _) = Refl + chadD1Id STAccum{} = error "accumulators not allowed in source program" + + chadD1EId :: SList STy l -> D1E l :~: l + chadD1EId SNil = Refl + chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl -- cgit v1.2.3-70-g09d2 From eed0f2999d6f6c8485ef53deb38f9d0a67b4f88e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 9 Jun 2025 23:07:36 +0200 Subject: WIP --- src/AST.hs | 45 +++++++++-------- src/AST/Sparse.hs | 35 +++++++++++++- src/CHAD.hs | 141 +++++++++++++++++++++++++++--------------------------- 3 files changed, 127 insertions(+), 94 deletions(-) diff --git a/src/AST.hs b/src/AST.hs index 0000836..b2ddbb4 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -461,27 +461,30 @@ eidxEq (SS n) a b (eidxEq n (EFst ext (EVar ext ty (IS IZ))) (EFst ext (EVar ext ty IZ))) -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr n t <- typeOf arr + , Dict <- styKnown t + = ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f + +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr n t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELet ext arr1 $ + ELet ext (weakenExpr WSink arr2) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ + weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) ezip arr1 arr2 = diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index ddae7fe..369d395 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -111,6 +111,37 @@ isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s isAbsent SpScal = False +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent _ _ = ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + data SBool b where SF :: SBool False @@ -120,7 +151,7 @@ deriving instance Show (SBool b) data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require - -- them. This eliminates pointless checks. + -- them. This simplifies the sparsePlusS code. Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b Noinj :: Injection False a b @@ -138,7 +169,7 @@ withInj2 Noinj _ _ = Noinj withInj2 _ Noinj _ = Noinj -- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. しょうがない。 +-- dynamic sparsity. TODO can this be improved? sparsePlusS :: SBool inj1 -> SBool inj2 -> SMTy t -> Sparse t t1 -> Sparse t t2 diff --git a/src/CHAD.hs b/src/CHAD.hs index 241825e..7cd4c26 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1094,42 +1094,39 @@ drev des accumMap sd = \case case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in - Ret (BTop `BPush` (shty, drevPrimal des she) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w'))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYesR (SENo (SEYesR SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + Ret (mergePrimalBindings + `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) + `BPush` (STArr ndim (STPair (d1 eltty) tapety) + ,EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #propr :++: #d1env)) + e0)) $ + let w = autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w'))) + `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) + (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in + (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in ESnd ext $ uninvertTup (d2e envPro) (STArr ndim STNil) $ - -- TODO: what's happening here is that because of the sparsity - -- rewrite, makeAccumulators needs primals where it previously - -- didn't. The build derivative is currently not saving those - -- primals, so the hole below cannot currently be filled. The - -- appropriate primals (waves hands) need to be stored, so that a - -- weakening can be provided here. - makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $ + makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -- the cotangent for this element ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) @@ -1148,10 +1145,11 @@ drev des accumMap sd = \case &. #darr (auto1 @(TArr ndim sdElt)) &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) &. #sh (auto1 @shty) + &. #propr (d1e envPro) &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des))) (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) .> wPro (subList (bindingsBinds e0) subtapeE)) e2) }}} @@ -1167,32 +1165,34 @@ drev des accumMap sd = \case weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (ezeroD2 eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (sparsePlus (d2M eltty) sdElt' + (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) + (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e , STArr _ t <- typeOf e -> Ret e0 subtape (EIdx0 ext e1) sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" {- @@ -1214,26 +1214,25 @@ drev des accumMap sd = \case -} EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e , Refl <- indexTupD1Id n - , Refl <- lemZeroInfoD2 eltty - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYesR (SEYesR (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `BPush` (STArr n (d1 eltty), e1) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere) + (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, -- cgit v1.2.3-70-g09d2 From 2b1a40b5933b8b0dceaae744e5b70cb604822c9d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 16 Jun 2025 23:21:55 +0200 Subject: CHAD.hs compiles --- src/AST.hs | 24 ++++++- src/AST/Accum.hs | 36 ++++++++--- src/AST/UnMonoid.hs | 2 +- src/CHAD.hs | 167 +++++++++++++++++++++++++++++++----------------- src/CHAD/Top.hs | 1 - src/CHAD/Types/ToTan.hs | 18 ++---- src/Interpreter.hs | 39 +++++++++-- src/Language.hs | 2 +- src/Language/AST.hs | 2 +- src/Simplify.hs | 106 +++++++++++++++++------------- 10 files changed, 261 insertions(+), 136 deletions(-) diff --git a/src/AST.hs b/src/AST.hs index b2ddbb4..c24e3e7 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -92,12 +92,12 @@ data Expr x env t where -- accumulation effect on monoids EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t -- interface of abstract monoidal types ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) @@ -523,6 +523,14 @@ eunPair e k = (EFst ext (evar IZ)) (ESnd ext (evar IZ)) +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b elet rhs body | Dict <- styKnown (typeOf rhs) @@ -543,3 +551,15 @@ elcase e a b c evar :: KnownTy a => Idx env a -> Ex env a evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 1101cc0..158b4d9 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where @@ -32,21 +33,36 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = +type data StillDense = AI_D | AI_S +data SStillDense dense where + SAI_D :: SStillDense AI_D + SAI_S :: SStillDense AI_S +deriving instance Show (SStillDense dense) + +type family AcIdx dense p t where + AcIdx dense APHere t = TNil + AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a + AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b + AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b) + AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b) + AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a + AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b + AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a + AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a) + AcIdx AI_S (APArrIdx p) (TArr n a) = -- ((index, shapes info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = + (AcIdx AI_S p a) + -- AcIdx AI_D (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AI_S (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AI_D p t +type AcIdxS p t = AcIdx AI_S p t + acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ac4d733..389dd5a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -105,7 +105,7 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> ELet ext arg $ diff --git a/src/CHAD.hs b/src/CHAD.hs index 7cd4c26..3dedec3 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -362,12 +362,6 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" -zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) -zeroTup SNil _ = ENil ext -zeroTup (t `SCons` env) w = - EPair ext (zeroTup env (WPop w)) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - ----------------------------------- SPARSITY ----------------------------------- @@ -780,7 +774,7 @@ drev des accumMap (SpSparse sd) = subtape e1 sub' - (emaybe (evar IZ) + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) (inj2 (ENil ext)) (inj1 (weakenExpr (WCopy WSink) e2))) } @@ -794,7 +788,8 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx -> + EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI))) Idx2Me tupI -> Ret BTop @@ -1227,43 +1222,45 @@ drev des accumMap sd = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) - (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1286,35 +1283,35 @@ drev des accumMap sd = \case err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Sparse (TArr n (D2s t')) sd' - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYesR (SEYesR subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s sd t = @@ -1379,7 +1376,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des)) in - RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library @@ -1412,3 +1409,59 @@ drevPrimal des e chadD1EId :: SList STy l -> D1E l :~: l chadD1EId SNil = Refl chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl + +accumulateSparse + :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil) + -> Ex env TNil +accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of + (_, _, s) | Just Refl <- isDense topty s -> + accum WId SAPHere arg (ENil ext) + (_, SMTScal _, SpScal) -> + accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh + (_, _, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w))) + (_, _, SpAbsent) -> + ENil ext + (SAI_D, SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SAI_S, SMTPair{}, SpPair{}) -> + error "TODO: accumulating into pair inside coproduct unimplemented" + -- There are two different ways this can be accomplished: + -- 1. Ensure we have the requisite ZeroInfo here. This means that an + -- accum-mode variable reference will (if its incoming cotangent is + -- sparse enough) need to store some ZeroInfo fragments computed from + -- the primal (not necessarily the entire primal). Doing this properly, + -- i.e. not just storing a full D1 but only the required ZeroInfo + -- fragments, is possible and not too inefficient but a bit of + -- engineering again. + -- 2. When creating an accumulator, don't initialise it with a generic + -- EZero based on a ZeroInfo, but instead a special "deep zero" based on + -- probably a full D1. This deep zero also initialises Left/Right/Just + -- modelled after the primal. With this, an accumulation needs no zero + -- info whatsoever (!) under the assumption that it receives a cotangent + -- that is compatible with the primal it is propagated back to. + (_, SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (_, SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SAI_D, SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse dense t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $ + ENil ext + (SAI_S, SMTArr{}, SpArr{}) -> + error "TODO: accumulating into array inside coproduct unimplemented" + -- See the pair case above, same reasoning diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 261ddfe..130174a 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -15,7 +15,6 @@ import AST import AST.SplitLets import AST.Weaken.Auto import CHAD -import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 8476712..888fed4 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of @@ -34,14 +32,12 @@ toTan typ primal der = case typ of (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 803a24a..b3576ce 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -162,7 +162,7 @@ interpret'Rec env = \case idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val + accumAddSparseD t p accum idx val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi @@ -239,7 +239,7 @@ addM typ a b = case typ of | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) @@ -274,7 +274,7 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of (_, SAPHere) -> newAcDense typ val @@ -291,9 +291,9 @@ newAcSparse typ prj idx val = case (typ, prj) of (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ziarr @@ -329,7 +329,34 @@ accumAddDense typ ref val = case typ of accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s () +accumAddSparseD typ prj ref idx val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) + (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) + (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (_, SAPHere) -> accumAddDense typ ref val diff --git a/src/Language.hs b/src/Language.hs index 7a780a0..63279df 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -175,7 +175,7 @@ recompute = NERecompute with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) with a (n :-> b) = NEWith (knownMTy @t) a n b -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil accum p a b c = NEAccum knownMTy p a b c diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 7e074df..92792b3 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -76,7 +76,7 @@ data NExpr env t where -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a diff --git a/src/Simplify.hs b/src/Simplify.hs index e110206..d3b850f 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -226,19 +226,19 @@ simplify'Rec = \case e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2') (acted $ return (ENil ext)) (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -373,27 +373,27 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) +data OneHotTerm dense env p a b where + OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b +deriving instance Show (OneHotTerm dense env p a b) -simplifyOneHotTerm :: OneHotTerm env p a b +simplifyOneHotTerm :: OneHotTerm dense env p a b -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) + -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r) -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do +simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 case val1' of EZero{} -> kzero EOneHot _ t2 prj2 idx2 val2 | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k + concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k _ -> case prj1 of SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) + _ -> k (OneHotTerm dense t1 prj1 idx1 val1) -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) @@ -433,52 +433,66 @@ recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> +concatOneHots :: SStillDense dense -> SMTy a + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b) + -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r +concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of + (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2) + (SAI_S, _, SAPHere) -> k prj2 idx2 + + (SAI_D, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPFst prj12) idx12 + (SAI_S, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + (SAI_D, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPSnd prj12) idx12 + (SAI_S, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 + (_, SMTLEither a _, SAPLeft prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 + (_, SMTLEither _ b, SAPRight prj1') -> + concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 + (_, SMTMaybe a, SAPJust prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + -- yes, twice the same code, but we need a concrete denseness indicator to + -- reduce AcIdx (the only difference between the dense and sparse versions is + -- whether there extra info also contains an array shape, and this code + -- handles the extra info uniformly) + (SAI_D, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + (SAI_S, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither{}, SAPLeft{}) -> e + (SMTLEither{}, SAPRight{}) -> e + (SMTMaybe{}, SAPJust{}) -> e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext -- cgit v1.2.3-70-g09d2 From d1b2e2c3a3cdaf49ff5e4bae6fe9b0612c3779c2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:00:11 +0200 Subject: Tests pass, should check if output is sensible --- chad-fast.cabal | 2 + src/AST.hs | 20 ++- src/AST/Accum.hs | 58 +++++---- src/AST/Count.hs | 2 +- src/AST/Pretty.hs | 21 +++- src/AST/Sparse.hs | 110 +---------------- src/AST/Sparse/Types.hs | 107 ++++++++++++++++ src/AST/SplitLets.hs | 2 +- src/AST/UnMonoid.hs | 111 ++++++++++++++++- src/Analysis/Identity.hs | 4 +- src/CHAD.hs | 117 +----------------- src/CHAD/Accum.hs | 45 +++++++ src/CHAD/Top.hs | 54 ++++----- src/CHAD/Types.hs | 16 +++ src/Compile.hs | 171 ++++++-------------------- src/Data.hs | 8 +- src/Example.hs | 3 +- src/Interpreter.hs | 151 ++++++++++------------- src/Language.hs | 6 +- src/Language/AST.hs | 5 +- src/Simplify.hs | 309 +++++++++++++++++++++++++++++++---------------- test/Main.hs | 29 +++-- 22 files changed, 726 insertions(+), 625 deletions(-) create mode 100644 src/AST/Sparse/Types.hs create mode 100644 src/CHAD/Accum.hs diff --git a/chad-fast.cabal b/chad-fast.cabal index b8510d2..b7270e4 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -19,12 +19,14 @@ library AST.Env AST.Pretty AST.Sparse + AST.Sparse.Types AST.SplitLets AST.Types AST.UnMonoid AST.Weaken AST.Weaken.Auto CHAD + CHAD.Accum CHAD.EnvDescr CHAD.Top CHAD.Types diff --git a/src/AST.hs b/src/AST.hs index c24e3e7..5aab4fc 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -25,6 +25,7 @@ import Data.Kind (Type) import Array import AST.Accum +import AST.Sparse.Types import AST.Types import AST.Weaken import CHAD.Types @@ -91,11 +92,16 @@ data Expr x env t where ERecompute :: x t -> Expr x env t -> Expr x env t -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t @@ -218,9 +224,10 @@ typeOf = \case ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t EPlus _ t _ _ -> fromSMTy t EOneHot _ t _ _ _ -> fromSMTy t @@ -261,8 +268,9 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -306,8 +314,9 @@ travExt f = \case ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 ERecompute x e -> ERecompute <$> f x <*> travExt f e EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s @@ -364,8 +373,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 158b4d9..619c2b1 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} @@ -33,35 +34,38 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type data StillDense = AI_D | AI_S -data SStillDense dense where - SAI_D :: SStillDense AI_D - SAI_S :: SStillDense AI_S -deriving instance Show (SStillDense dense) +type data AIDense = AID | AIS -type family AcIdx dense p t where - AcIdx dense APHere t = TNil - AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a - AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b - AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b) - AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b) - AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a - AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b - AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a - AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a) - AcIdx AI_S (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx AI_S p a) - -- AcIdx AI_D (APArrSlice m) (TArr n a) = + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = -- -- index -- Tup (Replicate m TIx) - -- AcIdx AI_S (APArrSlice m) (TArr n a) = + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -type AcIdxD p t = AcIdx AI_D p t -type AcIdxS p t = AcIdx AI_S p t +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t @@ -88,6 +92,16 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 03a36f6..05be524 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -134,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId ECustom _ _ _ _ _ _ _ a b -> re a <> re b ERecompute _ e -> re e EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e + EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 41da656..fef9686 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -25,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -304,18 +305,24 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] EZero _ t e1 -> do e1' <- ppExpr' 11 val e1 return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b @@ -368,6 +375,16 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 369d395..f0a1f2a 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -1,116 +1,19 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse where +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where -import Data.Kind (Constraint, Type) import Data.Type.Equality import AST +import AST.Sparse.Types +import Data (SBool(..)) -data Sparse t t' where - SpSparse :: Sparse t t' -> Sparse t (TMaybe t') - SpAbsent :: Sparse t TNil - - SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') - SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') - SpScal :: Sparse (TScal t) (TScal t) -deriving instance Show (Sparse t t') - -class ApplySparse f where - applySparse :: Sparse t t' -> f t -> f t' - -instance ApplySparse STy where - applySparse (SpSparse s) t = STMaybe (applySparse s t) - applySparse SpAbsent _ = STNil - applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) - applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) - applySparse SpScal t = t - -instance ApplySparse SMTy where - applySparse (SpSparse s) t = SMTMaybe (applySparse s t) - applySparse SpAbsent _ = SMTNil - applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) - applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse SpScal t = t - - -class IsSubType s where - type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint - subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c - subtFull :: IsSubTypeSubject s f => f t -> s t t - -instance IsSubType (:~:) where - type IsSubTypeSubject (:~:) f = () - subtApply = gcastWith - subtTrans = trans - subtFull _ = Refl - -instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ SMTy - subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - - subtFull = spDense - -spDense :: SMTy t -> Sparse t t -spDense SMTNil = SpAbsent -spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) -spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) -spDense (SMTMaybe t) = SpMaybe (spDense t) -spDense (SMTArr _ t) = SpArr (spDense t) -spDense (SMTScal _) = SpScal - -isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') -isDense SMTNil SpAbsent = Just Refl -isDense _ SpSparse{} = Nothing -isDense _ SpAbsent = Nothing -isDense (SMTPair t1 t2) (SpPair s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTLEither t1 t2) (SpLEither s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTMaybe t) (SpMaybe s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTArr _ t) (SpArr s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTScal _) SpScal = Just Refl - -isAbsent :: Sparse t t' -> Bool -isAbsent (SpSparse s) = isAbsent s -isAbsent SpAbsent = True -isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpMaybe s) = isAbsent s -isAbsent (SpArr s) = isAbsent s -isAbsent SpScal = False - sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' sparsePlus _ SpAbsent _ _ = ENil ext sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 @@ -143,11 +46,6 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 -data SBool b where - SF :: SBool False - ST :: SBool True -deriving instance Show (SBool b) - data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 3c353d4..2dad17a 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -63,7 +63,7 @@ splitLets' = \sub -> \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 389dd5a..d498aaa 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,18 +1,22 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import AST +import AST.Sparse.Types import Data --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -49,7 +53,10 @@ unMonoid = \case ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t @@ -66,6 +73,27 @@ zero (SMTScal t) _ = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil _ = ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t plus SMTNil _ _ = ENil ext plus (SMTPair t1 t2) a b = @@ -143,3 +171,78 @@ onehot typ topprj idx arg = case (typ, topprj) of (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 4501c32..2fd321d 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -307,11 +307,11 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') EZero _ t e1 -> do -- Approximate the result of EZero to be independent from the zero info diff --git a/src/CHAD.hs b/src/CHAD.hs index 3dedec3..621aa3e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -34,7 +34,6 @@ module CHAD ( import Data.Functor.Const import Data.Some -import Data.Type.Bool (If) import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) @@ -45,6 +44,7 @@ import AST.Count import AST.Env import AST.Sparse import AST.Weaken.Auto +import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -348,28 +348,8 @@ opt2UnSparse = go . opt2 go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------- MONOIDS ----------------------------------- - -d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) -d2zeroInfo STNil _ = ENil ext -d2zeroInfo (STPair a b) e = - eunPair e $ \_ e1 e2 -> - EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) -d2zeroInfo STEither{} _ = ENil ext -d2zeroInfo STLEither{} _ = ENil ext -d2zeroInfo STMaybe{} _ = ENil ext -d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e -d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext -d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" - - ----------------------------------- SPARSITY ----------------------------------- -subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') -subenvD1E SETop = SETop -subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) -subenvD1E (SENo sub) = SENo (subenvD1E sub) - expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e expandSparse t (SpSparse sp) epr e = @@ -499,23 +479,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- -makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators _ SNil e = e -makeAccumulators w (t `SCons` envpro) e = - makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing @@ -788,8 +751,7 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx -> - EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI))) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -1275,6 +1237,7 @@ drev des accumMap sd = \case EWith{} -> err_accum EZero{} -> err_monoid + EDeepZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid @@ -1392,76 +1355,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e - | Refl <- chadD1Id (typeOf e) - , Refl <- chadD1EId (descrList des) + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) = mapExt (const ext) e - where - chadD1Id :: STy a -> D1 a :~: a - chadD1Id STNil = Refl - chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl - chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl - chadD1Id (STScal _) = Refl - chadD1Id STAccum{} = error "accumulators not allowed in source program" - - chadD1EId :: SList STy l -> D1E l :~: l - chadD1EId SNil = Refl - chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl - -accumulateSparse - :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil) - -> Ex env TNil -accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of - (_, _, s) | Just Refl <- isDense topty s -> - accum WId SAPHere arg (ENil ext) - (_, SMTScal _, SpScal) -> - accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh - (_, _, SpSparse s) -> - emaybe arg - (ENil ext) - (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w))) - (_, _, SpAbsent) -> - ENil ext - (SAI_D, SMTPair t1 t2, SpPair s1 s2) -> - eunPair arg $ \w1 e1 e2 -> - elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ - accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) - (SAI_S, SMTPair{}, SpPair{}) -> - error "TODO: accumulating into pair inside coproduct unimplemented" - -- There are two different ways this can be accomplished: - -- 1. Ensure we have the requisite ZeroInfo here. This means that an - -- accum-mode variable reference will (if its incoming cotangent is - -- sparse enough) need to store some ZeroInfo fragments computed from - -- the primal (not necessarily the entire primal). Doing this properly, - -- i.e. not just storing a full D1 but only the required ZeroInfo - -- fragments, is possible and not too inefficient but a bit of - -- engineering again. - -- 2. When creating an accumulator, don't initialise it with a generic - -- EZero based on a ZeroInfo, but instead a special "deep zero" based on - -- probably a full D1. This deep zero also initialises Left/Right/Just - -- modelled after the primal. With this, an accumulation needs no zero - -- info whatsoever (!) under the assumption that it receives a cotangent - -- that is compatible with the primal it is propagated back to. - (_, SMTLEither t1 t2, SpLEither s1 s2) -> - elcase arg - (ENil ext) - (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) - (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) - (_, SMTMaybe t, SpMaybe s) -> - emaybe arg - (ENil ext) - (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) - (SAI_D, SMTArr n t, SpArr s) -> - let tn = tTup (sreplicate n tIx) in - elet arg $ - elet (EBuild ext n (EShape ext (evar IZ)) $ - accumulateSparse dense t s - (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) - (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $ - ENil ext - (SAI_S, SMTArr{}, SpArr{}) -> - error "TODO: accumulating into array inside coproduct unimplemented" - -- See the pair case above, same reasoning diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs new file mode 100644 index 0000000..8c7794a --- /dev/null +++ b/src/CHAD/Accum.hs @@ -0,0 +1,45 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD and CHAD.Top. +module CHAD.Accum where + +import AST +import CHAD.Types +import Data +import AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 130174a..484779e 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -12,9 +12,12 @@ module CHAD.Top where import Analysis.Identity import AST +import AST.Env +import AST.Sparse import AST.SplitLets import AST.Weaken.Auto import CHAD +import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -43,36 +46,22 @@ accumDescr (t `SCons` env) k = accumDescr env $ \des -> if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) else k (des `DPush` (t, Nothing, SMerge)) -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl - reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, _, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -82,21 +71,22 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty term')) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') where term' = identityAnalysis env (splitLets term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 83f013d..8b3a8db 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where @@ -124,3 +125,18 @@ lemZeroInfoScal STI64 = Refl lemZeroInfoScal STF32 = Refl lemZeroInfoScal STF64 = Refl lemZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/Compile.hs b/src/Compile.hs index 722b432..a5c4fb7 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -45,6 +45,7 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) +import AST.Sparse.Types (isDense) import Compile.Exec import Data import Interpreter.Rep @@ -1002,95 +1003,7 @@ compile' env = \case rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a - -- full zero array with the given zero info (for the type SMTArr n t1). - initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () - initZeroArray n t1 v vzi = do - shszname <- genName' "inacshsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) - newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) - emit $ SAsg v (CELit newarrName) - forM_ (initZeroFromMemset t1) $ \f1 -> do - ivar <- genName' "i" - ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts - - -- If something needs to be done to properly initialise this type to - -- zero after memory has already been initialised to all-zero bytes, - -- returns an action that does so. - -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) - initZeroFromMemset SMTNil = Nothing - initZeroFromMemset (SMTPair t1 t2) = - case (initZeroFromMemset t1, initZeroFromMemset t2) of - (Nothing, Nothing) -> Nothing - (mf1, mf2) -> Just $ \v vzi -> do - forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") - forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") - initZeroFromMemset SMTLEither{} = Nothing - initZeroFromMemset SMTMaybe{} = Nothing - initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi - initZeroFromMemset SMTScal{} = Nothing - - let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroZI :: SMTy a -> String -> String -> CompM () - initZeroZI SMTNil _ _ = return () - initZeroZI (SMTPair t1 t2) v vzi = do - initZeroZI t1 (v++".a") (vzi++".a") - initZeroZI t2 (v++".b") (vzi++".b") - initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZeroZI (SMTScal sty) v _ = case sty of - STI32 -> emit $ SAsg v (CELit "0") - STI64 -> emit $ SAsg v (CELit "0l") - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - - let -- Initialise an uninitialised accumulation value, potentially already - -- with the addend, potentially to zero depending on the nature of the - -- projection. - -- 1. If the projection indexes only through dense monoids before - -- reaching SAPHere, the thing cannot be initialised to zero with - -- only an AcIdx; it would need to model a zero after the addend, - -- which is stupid and redundant. In this case, we return Left: - -- (accumulation value) (AcIdx value) (addend value). - -- The addend is copied, not consumed. (We can't reliably _always_ - -- consume it, so it's not worth trying to do it sometimes.) - -- 2. Otherwise, a sparse monoid is found along the way, and we can - -- initalise the dense prefix of the path to zero by setting the - -- indexed-through sparse value to a sparse zero. Afterwards, the - -- main recursion can proceed further. In this case, we return - -- Right: (accumulation value) (AcIdx value) - -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) - initZeroChunk :: SMTy a -> SAcPrj p a b - -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend - (String -> String -> CompM ()) -- zero initialisation of sparse chunk - initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of - -- reached target before the first sparse constructor - (t1 , SAPHere ) -> Left $ \v _ addend -> do - incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend - emit $ SAsg v (CELit addend) - -- sparse types - (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - -- dense types - (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - f (v++".a") (i++".a") - initZeroZI t2 (v++".b") (i++".b") - (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do - initZeroZI t1 (v++".a") (i++".a") - f (v++".b") (i++".b") - (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - initZeroArray n t1 v (i++".a.b") - linidxvar <- genName' "li" - emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) - f (v++".buf->xs["++linidxvar++"]") (i++".b") - where - applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i - applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i - + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do let -- Add a value (s) into an existing accumulation value (d). If a sparse -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () @@ -1160,67 +1073,55 @@ compile' env = \case accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend - - accumRef (SMTLEither ta tb) prj0 v i addend = do - let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' - SAPRight prj' -> initZeroChunk tb prj' - subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") - tagval = case prj0 of SAPLeft{} -> "1" - SAPRight{} -> "2" - ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend - SAPRight prj' -> accumRef tb prj' subv i addend - case chunkres of - Left densef -> do - ((), stmtsSet) <- scope $ densef subv i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) - stmtsAdd -- TODO: emit check for consistency of tags? - Right sparsef -> do - ((), stmtsInit) <- scope $ sparsef subv i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty - forM_ stmtsAdd emit + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend + + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - case initZeroChunk tj prj' of - Left densef -> do - ((), stmtsSet1) <- scope $ densef (v++".j") i addend - ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) - stmtsAdd1 - Right sparsef -> do - ((), stmtsInit1) <- scope $ sparsef (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i addend + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval @@ -1234,6 +1135,9 @@ compile' env = \case return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1247,6 +1151,7 @@ compile' env = \case return $ CEStruct name [] EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" diff --git a/src/Data.hs b/src/Data.hs index e86aaa6..e6978c8 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -8,12 +8,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl)) where +module Data (module Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare import Data.GADT.Show import Data.Some +import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -184,3 +185,8 @@ instance Applicative Bag where instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/Example.hs b/src/Example.hs index d3f6d0d..b320ead 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -162,8 +162,7 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index b3576ce..ffc2929 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,7 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity @@ -35,6 +36,7 @@ import Debug.Trace import Array import AST import AST.Pretty +import AST.Sparse.Types import Data import Interpreter.Rep @@ -158,14 +160,17 @@ interpret'Rec env = \case initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do + EAccum _ t p e1 sp e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparseD t p accum idx val + accumAddSparseD t p accum idx sp val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -216,6 +221,19 @@ zeroM typ zi = case typ of STF32 -> 0.0 STF64 -> 0.0 +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of SMTNil -> () @@ -256,15 +274,6 @@ withAccum t _ initval f = AcM $ do val <- readAc t accum return (out, val) -newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) -newAcZero typ zi = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi - SMTLEither{} -> newIORef Nothing - SMTMaybe _ -> newIORef Nothing - SMTArr _ t -> arrayMapM (newAcZero t) zi - SMTScal sty -> numericIsNum sty $ newIORef 0 - newAcDense :: SMTy a -> Rep a -> IO (RepAc a) newAcDense typ val = case typ of SMTNil -> return () @@ -274,22 +283,6 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (_, SAPHere) -> newAcDense typ val - - (SMTPair t1 t2, SAPFst prj') -> - (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) - (SMTPair t1 t2, SAPSnd prj') -> - (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx - onehotArray :: Monad m => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" @@ -309,81 +302,67 @@ readAc typ val = case typ of SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () -accumAddDense typ ref val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> do - accumAddDense t1 (fst ref) (fst val) - accumAddDense t2 (snd ref) (snd val) - SMTLEither{} -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - SMTMaybe{} -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - SMTArr _ t1 -> - forM_ [0 .. arraySize ref - 1] $ \i -> - accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) - SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) - -accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s () -accumAddSparseD typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val - (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) (SMTArr n t1, SAPArrIdx prj') -> let (arrindex', idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ref linindex = toLinearIndex arrsh arrindex - in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val - -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val - - (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - - (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) - - (SMTArr n t1, SAPArrIdx prj') -> - let ((arrindex', ziarr), idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ziarr - linindex = toLinearIndex arrsh arrindex - in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) +-- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd diff --git a/src/Language.hs b/src/Language.hs index 63279df..4e6d604 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -17,6 +17,7 @@ module Language ( import Array import AST +import AST.Sparse.Types import AST.Types import CHAD.Types import Data @@ -176,7 +177,10 @@ with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum with a (n :-> b) = NEWith (knownMTy @t) a n b accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a b c +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 92792b3..be98ccf 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import AST.Sparse.Types import CHAD.Types import Data @@ -76,7 +77,7 @@ data NExpr env t where -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -221,7 +222,7 @@ fromNamedExpr val = \case NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s diff --git a/src/Simplify.hs b/src/Simplify.hs index d3b850f..74b6601 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE QuasiQuotes #-} @@ -19,13 +21,14 @@ import Control.Monad (ap) import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) import Debug.Trace import AST import AST.Count import AST.Pretty +import AST.Sparse.Types +import AST.UnMonoid (acPrjCompose) import Data import Simplify.TH @@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id) smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) smReconstruct core = SM (\ctx -> (Any False, ctx core)) -tellActed :: SM tenv tt env t () -tellActed = SM (\_ -> (Any True, ())) +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) -- more convenient in practice -acted :: SM tenv tt env t a -> SM tenv tt env t a +acted :: ActedMonad m => m a -> m a acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) -acted' :: (Any, a) -> (Any, a) -acted' (_, x) = (Any True, x) - -liftActed :: (Any, a) -> SM tenv tt env t a -liftActed pair = SM $ \_ -> pair - simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -167,10 +176,10 @@ simplify'Rec = \case ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> + EAccum _ t p e1 sp (ELet _ rhs body) acc -> acted $ simplify' $ ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) -- let () = e in () ~> e ELet _ e1 (ENil _) | STNil <- typeOf e1 -> @@ -194,6 +203,9 @@ simplify'Rec = \case EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + -- TODO: more array shape + EShape _ (EBuild _ _ e _) -> acted $ simplify' e + -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) @@ -222,23 +234,40 @@ simplify'Rec = \case acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2') + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') (acted $ return (ENil ext)) - (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2') + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\e -> acted $ return e) - (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -302,8 +331,9 @@ simplify'Rec = \case e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') - EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s cheapExpr :: Expr x env t -> Bool @@ -353,8 +383,9 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ -> True + EAccum _ _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -373,51 +404,161 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm dense env p a b where - OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b -deriving instance Show (OneHotTerm dense env p a b) - -simplifyOneHotTerm :: OneHotTerm dense env p a b - -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r) - -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do - val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 - case val1' of - EZero{} -> kzero - EOneHot _ t2 prj2 idx2 val2 - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do - tellActed -- record, whatever happens later, that we've modified something - concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k - _ -> case prj1 of - SAPHere -> ktriv val1 - _ -> k (OneHotTerm dense t1 prj1 idx1 val1) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' (a', b') -> return $ EPair ext a' b' recogniseMonoid typ@(SMTLEither t1 t2) expr = case expr of - ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e _ -> return expr recogniseMonoid typ@(SMTMaybe t1) expr = case expr of - ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e _ -> return expr recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted' $ do + acted $ do e' <- recogniseMonoid t e return $ ELet ext e' $ @@ -426,61 +567,21 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = (ENil ext)) (EVar ext (fromSMTy t) IZ) recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SStillDense dense -> SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r -concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of - (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2) - (SAI_S, _, SAPHere) -> k prj2 idx2 - - (SAI_D, SMTPair a _, SAPFst prj1') -> - concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> - k (SAPFst prj12) idx12 - (SAI_S, SMTPair a _, SAPFst prj1') -> - concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SAI_D, SMTPair _ b, SAPSnd prj1') -> - concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> - k (SAPSnd prj12) idx12 - (SAI_S, SMTPair _ b, SAPSnd prj1') -> - concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - - (_, SMTLEither a _, SAPLeft prj1') -> - concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (_, SMTLEither _ b, SAPRight prj1') -> - concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (_, SMTMaybe a, SAPJust prj1') -> - concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - -- yes, twice the same code, but we need a concrete denseness indicator to - -- reduce AcIdx (the only difference between the dense and sparse versions is - -- whether there extra info also contains an array shape, and this code - -- handles the extra info uniformly) - (SAI_D, SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (SAI_S, SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - -reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) reduceAcIdx topty topprj e = case (topty, topprj) of (_, SAPHere) -> ENil ext (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) - (SMTLEither{}, SAPLeft{}) -> e - (SMTLEither{}, SAPRight{}) -> e - (SMTMaybe{}, SAPJust{}) -> e + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e (SMTArr _ t, SAPArrIdx p) -> eunPair e $ \_ e1 e2 -> EPair ext (efst e1) (reduceAcIdx t p e2) diff --git a/test/Main.hs b/test/Main.hs index 1b83a2e..d79e63f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -435,11 +435,22 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -502,22 +513,22 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil @@ -556,9 +567,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -566,6 +575,8 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x -- cgit v1.2.3-70-g09d2 From 2b00a57f565a42b1079a071e2db630ba22c7120d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:07:48 +0200 Subject: TODO deep zero in accum + fix warnings --- test/Main.hs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/Main.hs b/test/Main.hs index d79e63f..8da7598 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -577,6 +577,30 @@ tests_AD = testGroup "AD" ,adTest "build1-idx" term_build1_idx + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x -- cgit v1.2.3-70-g09d2 From 62639875102decae2bb96b3847ae48db5d1f8fd0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:09:56 +0200 Subject: Complete pattern matches --- src/AST/Count.hs | 1 + src/AST/SplitLets.hs | 1 + src/Analysis/Identity.hs | 7 +++++++ src/ForwardAD/DualNumbers.hs | 1 + 4 files changed, 10 insertions(+) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 05be524..ca4d7ab 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -136,6 +136,7 @@ occCountGeneral onehot unpush alter many = go WId EWith _ _ a b -> re a <> re1 b EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e + EDeepZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 2dad17a..dcaf82f 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -65,6 +65,7 @@ splitLets' = \sub -> \case EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 2fd321d..b54946b 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -320,6 +320,13 @@ idana env expr = case expr of res <- genIds (fromSMTy t) pure (res, EZero res t e1') + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index a6d5ec8..3ab08af 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -190,6 +190,7 @@ dfwdDN = \case EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid -- cgit v1.2.3-70-g09d2 From 3db7d00b3248d746aa99f57b117d5722cbe90df0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:10:30 +0200 Subject: Give DeepZero to With --- src/AST/Accum.hs | 8 ++++++++ src/CHAD.hs | 2 +- src/CHAD/Accum.hs | 24 +++++++++++++++++++++++- src/CHAD/Types.hs | 7 +++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 619c2b1..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -102,6 +102,14 @@ type family DeepZeroInfo t where DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) DeepZeroInfo (TScal t) = TNil +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/CHAD.hs b/src/CHAD.hs index 621aa3e..9fa7f9a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1341,7 +1341,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of in RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in - EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: (#body :++: #p) :++: #tl)) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index 8c7794a..7212232 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -22,11 +22,33 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators _ SNil e = e makeAccumulators w (t `SCons` envpro) e = makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 8b3a8db..e061588 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -126,6 +126,13 @@ lemZeroInfoScal STF32 = Refl lemZeroInfoScal STF64 = Refl lemZeroInfoScal STBool = Refl +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + d1Identity :: STy t -> D1 t :~: t d1Identity = \case STNil -> Refl -- cgit v1.2.3-70-g09d2 From 58f68a4d077c2d58c3974ad12853207512277a33 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:10:50 +0200 Subject: Put smart accumulator redirection behind config flag --- src/CHAD.hs | 3 ++- src/CHAD/Types.hs | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/CHAD.hs b/src/CHAD.hs index 9fa7f9a..3399de2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1310,7 +1310,8 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum - | Just (VIArr i _) <- argids + | chcSmartWith ?config + , Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index e061588..44ac20e 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -97,6 +97,8 @@ data CHADConfig = CHADConfig chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool } deriving (Show) @@ -105,12 +107,14 @@ defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False + , chcSmartWith = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True - , chcArgArrayAccum = True } + , chcArgArrayAccum = True + , chcSmartWith = True } ------------------------------------ LEMMAS ------------------------------------ -- cgit v1.2.3-70-g09d2 From fe80b31555c27f038b20eb84eb1e747781d7c76b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:11:12 +0200 Subject: Don't destroy effects in sparse plus --- src/AST/Sparse.hs | 20 ++++++++++++-------- src/CHAD.hs | 3 ++- test/Main.hs | 15 +++++++++------ 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index f0a1f2a..0c5bdb0 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -66,6 +66,9 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g) withInj2 Noinj _ _ = Noinj withInj2 _ Noinj _ = Noinj +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + -- | This function produces quadratically-sized code in the presence of nested -- dynamic sparsity. TODO can this be improved? sparsePlusS @@ -77,16 +80,17 @@ sparsePlusS -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) -> r) -> r --- nil override -sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) -- simplifications sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = let ta = applySparse sp1 (fromSMTy t) in @@ -144,13 +148,13 @@ sparsePlusS _ _ t sp1 sp2 k = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) -- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) sparsePlusS ST _ t SpAbsent sp2 k = - k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) sparsePlusS _ ST t sp1 SpAbsent k = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) -- double sparse yields sparse sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = diff --git a/src/CHAD.hs b/src/CHAD.hs index 3399de2..9a08457 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -404,7 +404,8 @@ subenvPlus :: SBool req1 -> SBool req2 -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext) +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> diff --git a/test/Main.hs b/test/Main.hs index 8da7598..5ec9dbc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -451,6 +451,14 @@ term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ idx0 $ sum1i $ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -583,12 +591,7 @@ tests_AD = testGroup "AD" let_ #p (#x ! pair nil #i) $ 3 * fst_ #p + 2 * snd_ #p - ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ - let_ #n (snd_ (shape #x)) $ - idx0 $ sum1i $ build1 #n $ #i :-> - case_ (#x ! pair nil #i) - (#a :-> #a * 2) - (#b :-> #b * 3) + ,adTest "idx-coprod" $ term_idx_coprod ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ let_ #n (snd_ (shape #arr)) $ -- cgit v1.2.3-70-g09d2 From 5a0a5e9ef69926265289ae5229e68060a7c77a27 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:29:16 +0200 Subject: Don't introduce sparsity if zero is cheap --- src/AST/Sparse.hs | 35 ++++++++++++++++++++++++++++++----- src/CHAD.hs | 37 +++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 0c5bdb0..34a398f 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -1,8 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where @@ -46,6 +47,24 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require @@ -149,12 +168,18 @@ sparsePlusS _ _ t sp1 sp2 k -- handle absents sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) -sparsePlusS ST _ t SpAbsent sp2 k = - k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) -sparsePlusS _ ST t sp1 SpAbsent k = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) -- double sparse yields sparse sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = diff --git a/src/CHAD.hs b/src/CHAD.hs index 9a08457..143376a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -423,18 +423,31 @@ subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (ESnd ext (EVar ext (typeOf e1) IZ))) -subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k = - subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> - k (SEYes (SpSparse sp1) sub3) - (withInj minj13 $ \inj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) (EJust ext e1b)) - (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) - (\e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> -- cgit v1.2.3-70-g09d2 From b20a3cf72522a88d73ab1d6f03c13e5705c7ab8e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 11:24:30 +0200 Subject: test-framework: Correct line count when collapsing with nested subgroups --- test-framework/Test/Framework.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index e0dc4b3..1b2b7d7 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -190,7 +190,7 @@ runTests options = \tree' -> "\x1B[32mOK\x1B[0m" ++ prettyDuration False (realToFrac (diffUTCTime endtm starttm)) return (Just 1) - _ -> return mlns + _ -> return ((+1) <$> mlns) go indent path (Resource make cleanup fun) = do value <- liftIO make success <- go indent path (fun value) -- cgit v1.2.3-70-g09d2 From a45bf0fd84d8e604613e9e557ae80143f1a41004 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 11:25:13 +0200 Subject: test: Test both default and accum configs --- test/Main.hs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index 5ec9dbc..3847920 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -305,7 +305,9 @@ adTestGen name expr envGenerator = testGroupCollapse name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS - ,adTestGenChad env envGenerator expr exprS primalSfun] + ,testGroup "chad" + [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun + ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]] adTestGenPrimal :: SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R @@ -336,19 +338,19 @@ adTestGenFwd env envGenerator exprS = diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 -adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) +adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree -adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 - dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> - testProperty "chad" $ property $ do + testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- pack Text for less GC pressure (these values are retained for some reason) -- cgit v1.2.3-70-g09d2 From 6d25e87e6f703395038d23aaff225aa502283519 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 13:53:00 +0200 Subject: test: Diligently check UnMonoid correctness --- test/Main.hs | 58 +++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index 3847920..0a57cbf 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -345,17 +345,21 @@ adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 + dtermChadSUS = simplifyFix $ unMonoid dtermChadS dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 + dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> + withCompiled env dtermSChadSUS $ \dcompSChadSUS -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - -- pack Text for less GC pressure (these values are retained for some reason) + -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0))) diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -365,17 +369,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS - - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + + (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input) + let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) @@ -383,17 +391,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outCompSChadS closeIsh outPrimal + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outCompSChadSUS closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansCompSChadS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansCompSChadSUS closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -- cgit v1.2.3-70-g09d2 From 48e4977f3e0a88ff24410987b80bf6003c45dfb7 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 13:56:42 +0200 Subject: Don't destroy effects in UnMonoid --- src/AST/UnMonoid.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index d498aaa..ef01bf8 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -60,7 +60,7 @@ unMonoid = \case EError _ t s -> EError ext t s zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -zero SMTNil _ = ENil ext +zero SMTNil e = elet e $ ENil ext zero (SMTPair t1 t2) e = ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) @@ -74,7 +74,7 @@ zero (SMTScal t) _ = case t of STF64 -> EConst ext STF64 0.0 deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t -deepZero SMTNil _ = ENil ext +deepZero SMTNil e = elet e $ ENil ext deepZero (SMTPair t1 t2) e = ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) @@ -95,7 +95,8 @@ deepZero (SMTScal t) _ = case t of STF64 -> EConst ext STF64 0.0 plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -plus SMTNil _ _ = ENil ext +-- don't destroy the effects! +plus SMTNil a b = elet a $ elet (weakenExpr WSink b) $ ENil ext plus (SMTPair t1 t2) a b = let t = STPair (fromSMTy t1) (fromSMTy t2) in ELet ext a $ -- cgit v1.2.3-70-g09d2 From a4b3eb76acbec30ffeae119a4dc6e4c9f64396fe Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 14:10:47 +0200 Subject: Some more effects to not ignore --- src/AST/Sparse.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 34a398f..93258b7 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -16,7 +16,7 @@ import Data (SBool(..)) sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' -sparsePlus _ SpAbsent _ _ = ENil ext +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = -- cgit v1.2.3-70-g09d2