aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-05 21:52:53 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-05 21:52:53 +0100
commit38150f4f9792156d8c59439fe47ecb69a0a0e00b (patch)
tree873bcfb4c952cf14ce15414e6f601cd1a9186346
parente08936de193f71ea83b472fc9a2eaf77eb84f11b (diff)
Implement D[map]
-rw-r--r--src/CHAD.hs50
-rw-r--r--src/Language.hs7
-rw-r--r--src/Language/AST.hs2
3 files changed, 52 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a37edff..298d964 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1083,7 +1083,55 @@ drev des accumMap sd = \case
(#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap{} -> error "TODO: CHAD EMap"
+ EMap _ ef (earr :: Expr _ _ (TArr n a))
+ | SpArr sdElt <- sd
+ , let STArr ndim t1 = typeOf earr
+ t2 = typeOf ef ->
+ drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 ->
+ case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds
+ ttape = typeOf ef1tape
+ library = #d1env (desD1E des)
+ &. #a0 (bindingsBinds ea0)
+ &. #atapebinds (subList (bindingsBinds ea0) easubtape)
+ &. #propr (d1e provars)
+ &. #x (d1 t1 `SCons` SNil)
+ &. #parr (STArr ndim (d1 t1) `SCons` SNil)
+ &. #tapearr (STArr ndim ttape `SCons` SNil)
+ &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil)
+ &. #dy (applySparse sdElt (d2 t2) `SCons` SNil)
+ &. #tape (ttape `SCons` SNil)
+ &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #pro (d2ace provars)
+ in
+ subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a ->
+ Ret (bconcat ea0 proPrimalBinds'
+ `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1
+ `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env))
+ (letBinds ef0 $
+ EPair ext ef1 ef1tape))
+ (EVar ext (STArr ndim (d1 t1)) IZ)
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ))
+ (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars))))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ)))
+ subfa
+ (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout) $
+ emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $
+ elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv)
+ (#tape :++: #dy :++: #dytape :++: #pro :++: layout))
+ ef2)
+ (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ))
+ (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $
+ plus_f_a
+ (ESnd ext (evar IZ))
+ (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout))
+ (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ))
+ ea2)))
+ }
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
diff --git a/src/Language.hs b/src/Language.hs
index c1a6248..4886ddc 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -108,12 +108,7 @@ build n a (v :-> b) = NEBuild n a v b
map_ :: forall n a b env name. (KnownNat n, KnownTy a)
=> (Var name a :-> NExpr ('(name, a) : env) b)
-> NExpr env (TArr n a) -> NExpr env (TArr n b)
-map_ (v :-> a) b
- | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) =
- let_ #arg b $
- build knownNat (shape #arg) $ #i :->
- let_ v (#arg ! #i) $
- NEDrop (SS SZ) (NEDrop (SS SZ) a)
+map_ (v :-> a) b = NEMap v a b
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index a3b8130..3d6ede5 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -52,6 +52,7 @@ data NExpr env t where
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
+ NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t)
NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
@@ -216,6 +217,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
+ NEMap n a b -> EMap ext (lambda val n a) (go b)
NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)