summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-11 09:35:35 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-11 09:35:35 +0100
commit281229b7bf307132a428dde1b171e1db86637238 (patch)
treef5fbe70ef80132ac1ea084cd5731998a9a55f677
parenta46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (diff)
Make EBuild derivative aware of zero cotangent arrays
-rw-r--r--src/AST.hs14
-rw-r--r--src/CHAD.hs56
2 files changed, 43 insertions, 27 deletions
diff --git a/src/AST.hs b/src/AST.hs
index e7dde90..9ad0d4d 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -408,6 +408,20 @@ ezip a b =
(EIdx ext (EVar ext (STArr n t2) (IS IZ))
(EVar ext (tTup (sreplicate n tIx)) IZ))
+eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
+eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
+
+-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
+eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eshapeEmpty SZ _ = EConst ext STBool False
+eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
+eshapeEmpty (SS n) e =
+ ELet ext e $
+ EOp ext OAnd (EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
+ (EConst ext STI64 0)))
+ (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
where
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8b9f17a..45fcc82 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1016,33 +1016,35 @@ drev des = \case
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
(let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (tapety `SCons` SNil)
- &. #ix (shty `SCons` SNil)
- &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
- &. #tapearr (STArr ndim tapety `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
+ eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ)))
+ (zeroTup envPro)
+ (ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (tapety `SCons` SNil)
+ &. #ix (shty `SCons` SNil)
+ &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
+ &. #tapearr (STArr ndim tapety `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2))
}}
EUnit _ e