summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 2f05807..a08fe80 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1069,6 +1069,9 @@ drev des = \case
(EVar ext (STArr n (d2 t)) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) e2)
+ EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1086,3 +1089,24 @@ drev des = \case
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
+
+ deriv_extremum :: ScalIsNumeric t' ~ True
+ => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
+ -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
+ deriv_extremum extremum e
+ | Ret e0 subtape e1 sub e2 <- drev des e
+ , at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYes (SEYes subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))))
+ (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
+ (EZero t)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)