diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 /src/CHAD.hs | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 75 |
1 files changed, 45 insertions, 30 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index a5a5719..be308cd 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -846,18 +846,18 @@ drev des = \case (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ))) + (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in + EMaybe ext (zeroTup envPro) (ESnd ext $ uninvertTup (d2e envPro) (STArr ndim STNil) $ makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ -- the cotangent for this element ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) (EVar ext shty IZ)) $ -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) (EVar ext shty (IS IZ))) $ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ in letBinds rebinds $ @@ -865,17 +865,19 @@ drev des = \case &. #pro (d2ace envPro) &. #etape (subList (bindingsBinds e0) subtapeE) &. #prerebinds prerebinds - &. #tape (tapety `SCons` SNil) - &. #ix (shty `SCons` SNil) - &. #darr (STArr ndim (d2 eltty) `SCons` SNil) - &. #tapearr (STArr ndim tapety `SCons` SNil) - &. #sh (shty `SCons` SNil) + &. #tape (auto1 @(Tape e_tape)) + &. #ix (auto1 @shty) + &. #darr (auto1 @(TArr ndim (D2 eltty))) + &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) + &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) + &. #sh (auto1 @shty) &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des))) (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) .> wPro (subList (bindingsBinds e0) subtapeE)) - e2)) + e2) + (EVar ext (d2 (STArr ndim eltty)) IZ)) }} EUnit _ e @@ -884,8 +886,11 @@ drev des = \case subtape (EUnit ext e1) sub - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy WSink) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'ei' is discrete. @@ -896,11 +901,14 @@ drev des = \case subtape (EReplicate1Inner ext en1 e1) sub - (ELet ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ - weakenExpr (WCopy WSink) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext (EFold1Inner ext Commut + (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) EIdx0 _ e | Ret e0 subtape e1 sub e2 <- drev des e @@ -909,7 +917,7 @@ drev des = \case subtape (EIdx0 ext e1) sub - (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ + (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" @@ -971,10 +979,13 @@ drev des = \case (SEYes (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (ELet ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 t)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (EVar ext (d2 (STArr n t)) IZ)) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e @@ -1010,13 +1021,17 @@ drev des = \case (SEYes (SEYes subtape)) (EVar ext at' IZ) sub - (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ - ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) - (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) - (EZero ext t)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) + (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (EZero ext t))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) + (EVar ext (d2 at') IZ)) data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) |