summaryrefslogtreecommitdiff
path: root/src/ForwardAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD')
-rw-r--r--src/ForwardAD/DualNumbers.hs39
1 files changed, 32 insertions, 7 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 056fcb3..7d47e6d 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -98,6 +98,12 @@ dop = \case
in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
(EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
+zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
+zeroScalarConst STF32 = EConst ext STF32 0.0
+zeroScalarConst STF64 = EConst ext STF64 0.0
+zeroScalarConst STI32 = EConst ext STI32 0
+zeroScalarConst STI64 = EConst ext STI64 0
+
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
EVar _ t i -> EVar ext (dn t) (convIdx i)
@@ -131,13 +137,8 @@ dfwdDN = \case
(ESum1Inner ext (dfwdDN e))
EUnit _ e -> EUnit ext (dfwdDN e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
- EMaximum1Inner _ e ->
- let STArr n (STScal t) = typeOf e
- in scalTyCase t
- -- TODO: do roughly the same as what CHAD does, but forward
- (_ (dfwdDN e))
- _
- EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e)
+ EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
EConst _ t x -> scalTyCase t
(EPair ext (EConst ext t x) (EConst ext t 0.0))
(EConst ext t x)
@@ -165,3 +166,27 @@ dfwdDN = \case
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
+
+ deriv_extremum :: ScalIsNumeric t ~ True
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
+ deriv_extremum extremum e =
+ let STArr (SS n) (STScal t) = typeOf e
+ t2 = STPair (STScal t) (STScal t)
+ ta2 = STArr (SS n) t2
+ tIxN = tTup (sreplicate (SS n) tIx)
+ in scalTyCase t
+ (ELet ext (dfwdDN e) $
+ ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
+ ezip (EVar ext (STArr n (STScal t)) IZ)
+ (ESum1Inner ext
+ {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
+ (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
+ ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
+ (EFst ext (EVar ext t2 IZ))
+ (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
+ (EFst ext (EVar ext tIxN (IS IZ)))))))
+ (ESnd ext (EVar ext t2 (IS IZ)))
+ (zeroScalarConst t))))
+ (EMaximum1Inner ext (dfwdDN e))