diff options
Diffstat (limited to 'src/ForwardAD')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 056fcb3..7d47e6d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -98,6 +98,12 @@ dop = \case in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) +zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 +zeroScalarConst STI32 = EConst ext STI32 0 +zeroScalarConst STI64 = EConst ext STI64 0 + dfwdDN :: Ex env t -> Ex (DNE env) (DN t) dfwdDN = \case EVar _ t i -> EVar ext (dn t) (convIdx i) @@ -131,13 +137,8 @@ dfwdDN = \case (ESum1Inner ext (dfwdDN e)) EUnit _ e -> EUnit ext (dfwdDN e) EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) - EMaximum1Inner _ e -> - let STArr n (STScal t) = typeOf e - in scalTyCase t - -- TODO: do roughly the same as what CHAD does, but forward - (_ (dfwdDN e)) - _ - EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -165,3 +166,27 @@ dfwdDN = \case where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + + deriv_extremum :: ScalIsNumeric t ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) + deriv_extremum extremum e = + let STArr (SS n) (STScal t) = typeOf e + t2 = STPair (STScal t) (STScal t) + ta2 = STArr (SS n) t2 + tIxN = tTup (sreplicate (SS n) tIx) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ + ezip (EVar ext (STArr n (STScal t)) IZ) + (ESum1Inner ext + {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} + (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ + ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ + ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext + (EFst ext (EVar ext t2 IZ)) + (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) + (EFst ext (EVar ext tIxN (IS IZ))))))) + (ESnd ext (EVar ext t2 (IS IZ))) + (zeroScalarConst t)))) + (EMaximum1Inner ext (dfwdDN e)) |