diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 2f05807..a08fe80 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1069,6 +1069,9 @@ drev des = \case (EVar ext (STArr n (d2 t)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1086,3 +1089,24 @@ drev des = \case err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + + deriv_extremum :: ScalIsNumeric t' ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) + -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) + deriv_extremum extremum e + | Ret e0 subtape e1 sub e2 <- drev des e + , at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYes (SEYes subtape)) + (EVar ext at' IZ) + sub + (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) + (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) + (EZero t)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) |