summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-18 10:10:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-18 10:10:30 +0200
commit3db7d00b3248d746aa99f57b117d5722cbe90df0 (patch)
treed01a0f9c1e95de07d6dc3fcb8f6895bf94a77165
parent62639875102decae2bb96b3847ae48db5d1f8fd0 (diff)
Give DeepZero to With
-rw-r--r--src/AST/Accum.hs8
-rw-r--r--src/CHAD.hs2
-rw-r--r--src/CHAD/Accum.hs24
-rw-r--r--src/CHAD/Types.hs7
4 files changed, 39 insertions, 2 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 619c2b1..988a450 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -102,6 +102,14 @@ type family DeepZeroInfo t where
DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
DeepZeroInfo (TScal t) = TNil
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 621aa3e..9fa7f9a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1341,7 +1341,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
in
RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
- EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: (#body :++: #p) :++: #tl))
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index 8c7794a..7212232 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -22,11 +22,33 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators _ SNil e = e
makeAccumulators w (t `SCons` envpro) e =
makeAccumulators (WPop w) envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 8b3a8db..e061588 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -126,6 +126,13 @@ lemZeroInfoScal STF32 = Refl
lemZeroInfoScal STF64 = Refl
lemZeroInfoScal STBool = Refl
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
d1Identity :: STy t -> D1 t :~: t
d1Identity = \case
STNil -> Refl