summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-18 10:10:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-18 10:10:30 +0200
commit3db7d00b3248d746aa99f57b117d5722cbe90df0 (patch)
treed01a0f9c1e95de07d6dc3fcb8f6895bf94a77165 /src/CHAD
parent62639875102decae2bb96b3847ae48db5d1f8fd0 (diff)
Give DeepZero to With
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Accum.hs24
-rw-r--r--src/CHAD/Types.hs7
2 files changed, 30 insertions, 1 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index 8c7794a..7212232 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -22,11 +22,33 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators _ SNil e = e
makeAccumulators w (t `SCons` envpro) e =
makeAccumulators (WPop w) envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 8b3a8db..e061588 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -126,6 +126,13 @@ lemZeroInfoScal STF32 = Refl
lemZeroInfoScal STF64 = Refl
lemZeroInfoScal STBool = Refl
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
d1Identity :: STy t -> D1 t :~: t
d1Identity = \case
STNil -> Refl