diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-11 09:35:35 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-11 09:35:35 +0100 |
commit | 281229b7bf307132a428dde1b171e1db86637238 (patch) | |
tree | f5fbe70ef80132ac1ea084cd5731998a9a55f677 | |
parent | a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (diff) |
Make EBuild derivative aware of zero cotangent arrays
-rw-r--r-- | src/AST.hs | 14 | ||||
-rw-r--r-- | src/CHAD.hs | 56 |
2 files changed, 43 insertions, 27 deletions
@@ -408,6 +408,20 @@ ezip a b = (EIdx ext (EVar ext (STArr n t2) (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) +eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a +eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) + +-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). +eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eshapeEmpty SZ _ = EConst ext STBool False +eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) +eshapeEmpty (SS n) e = + ELet ext e $ + EOp ext OAnd (EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) + (EConst ext STI64 0))) + (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) where diff --git a/src/CHAD.hs b/src/CHAD.hs index 8b9f17a..45fcc82 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1016,33 +1016,35 @@ drev des = \case (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvCompose subMergeUsed proSub) (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (tapety `SCons` SNil) - &. #ix (shty `SCons` SNil) - &. #darr (STArr ndim (d2 eltty) `SCons` SNil) - &. #tapearr (STArr ndim tapety `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) + eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ))) + (zeroTup envPro) + (ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) + &. #pro (d2ace envPro) + &. #etape (subList (bindingsBinds e0) subtapeE) + &. #prerebinds prerebinds + &. #tape (tapety `SCons` SNil) + &. #ix (shty `SCons` SNil) + &. #darr (STArr ndim (d2 eltty) `SCons` SNil) + &. #tapearr (STArr ndim tapety `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + .> wPro (subList (bindingsBinds e0) subtapeE)) + e2)) }} EUnit _ e |