diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/AST/Accum.hs | 8 | ||||
-rw-r--r-- | src/CHAD.hs | 2 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 24 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 7 |
4 files changed, 39 insertions, 2 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 619c2b1..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -102,6 +102,14 @@ type family DeepZeroInfo t where DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) DeepZeroInfo (TScal t) = TNil +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/CHAD.hs b/src/CHAD.hs index 621aa3e..9fa7f9a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1341,7 +1341,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of in RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in - EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: (#body :++: #p) :++: #tl)) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index 8c7794a..7212232 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -22,11 +22,33 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators _ SNil e = e makeAccumulators w (t `SCons` envpro) e = makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 8b3a8db..e061588 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -126,6 +126,13 @@ lemZeroInfoScal STF32 = Refl lemZeroInfoScal STF64 = Refl lemZeroInfoScal STBool = Refl +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + d1Identity :: STy t -> D1 t :~: t d1Identity = \case STNil -> Refl |