summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
commita00234388d1b4e14481067d030bf90031258b756 (patch)
tree501b6778fc5779ce220aba1e22f56ae60f68d970 /src/CHAD.hs
parent7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff)
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs75
1 files changed, 45 insertions, 30 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a5a5719..be308cd 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -846,18 +846,18 @@ drev des = \case
(emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ)))
+ (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
+ EMaybe ext
(zeroTup envPro)
(ESnd ext $
uninvertTup (d2e envPro) (STArr ndim STNil) $
makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
-- the cotangent for this element
ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
(EVar ext shty IZ)) $
-- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
(EVar ext shty (IS IZ))) $
let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
in letBinds rebinds $
@@ -865,17 +865,19 @@ drev des = \case
&. #pro (d2ace envPro)
&. #etape (subList (bindingsBinds e0) subtapeE)
&. #prerebinds prerebinds
- &. #tape (tapety `SCons` SNil)
- &. #ix (shty `SCons` SNil)
- &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
- &. #tapearr (STArr ndim tapety `SCons` SNil)
- &. #sh (shty `SCons` SNil)
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim (D2 eltty)))
+ &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
&. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des)))
(#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
.> wPro (subList (bindingsBinds e0) subtapeE))
- e2))
+ e2)
+ (EVar ext (d2 (STArr ndim eltty)) IZ))
}}
EUnit _ e
@@ -884,8 +886,11 @@ drev des = \case
subtape
(EUnit ext e1)
sub
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy WSink) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
EReplicate1Inner _ en e
-- We're allowed to ignore en2 here because the output of 'ei' is discrete.
@@ -896,11 +901,14 @@ drev des = \case
subtape
(EReplicate1Inner ext en1 e1)
sub
- (ELet ext (EFold1Inner ext Commut
- (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (EZero ext eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
- weakenExpr (WCopy WSink) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext (EFold1Inner ext Commut
+ (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext eltty)
+ (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
EIdx0 _ e
| Ret e0 subtape e1 sub e2 <- drev des e
@@ -909,7 +917,7 @@ drev des = \case
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $
+ (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
@@ -971,10 +979,13 @@ drev des = \case
(SEYes (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (ELet ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
- (EVar ext (STArr n (d2 t)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
+ (EVar ext (STArr n (d2 t)) IZ))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (EVar ext (d2 (STArr n t)) IZ))
EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
@@ -1010,13 +1021,17 @@ drev des = \case
(SEYes (SEYes subtape))
(EVar ext at' IZ)
sub
- (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
- ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))))
- (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
- (EZero ext t)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
+ (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (EZero ext t))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
+ (EVar ext (d2 at') IZ))
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)