aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs50
1 files changed, 49 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a37edff..298d964 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1083,7 +1083,55 @@ drev des accumMap sd = \case
(#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap{} -> error "TODO: CHAD EMap"
+ EMap _ ef (earr :: Expr _ _ (TArr n a))
+ | SpArr sdElt <- sd
+ , let STArr ndim t1 = typeOf earr
+ t2 = typeOf ef ->
+ drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 ->
+ case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds
+ ttape = typeOf ef1tape
+ library = #d1env (desD1E des)
+ &. #a0 (bindingsBinds ea0)
+ &. #atapebinds (subList (bindingsBinds ea0) easubtape)
+ &. #propr (d1e provars)
+ &. #x (d1 t1 `SCons` SNil)
+ &. #parr (STArr ndim (d1 t1) `SCons` SNil)
+ &. #tapearr (STArr ndim ttape `SCons` SNil)
+ &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil)
+ &. #dy (applySparse sdElt (d2 t2) `SCons` SNil)
+ &. #tape (ttape `SCons` SNil)
+ &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #pro (d2ace provars)
+ in
+ subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a ->
+ Ret (bconcat ea0 proPrimalBinds'
+ `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1
+ `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env))
+ (letBinds ef0 $
+ EPair ext ef1 ef1tape))
+ (EVar ext (STArr ndim (d1 t1)) IZ)
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ))
+ (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars))))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ)))
+ subfa
+ (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout) $
+ emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $
+ elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv)
+ (#tape :++: #dy :++: #dytape :++: #pro :++: layout))
+ ef2)
+ (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ))
+ (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $
+ plus_f_a
+ (ESnd ext (evar IZ))
+ (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout))
+ (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ))
+ ea2)))
+ }
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd