aboutsummaryrefslogtreecommitdiff
path: root/src/ForwardAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD')
-rw-r--r--src/ForwardAD/DualNumbers.hs11
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs2
2 files changed, 12 insertions, 1 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index aa35a5b..3ab08af 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -86,6 +86,9 @@ dop = \case
OIDiv t -> scalTyCase t
(case t of {})
(EOp ext (OIDiv t))
+ OMod t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OMod t))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
@@ -140,13 +143,17 @@ dfwdDN = \case
ENothing _ t -> ENothing ext (dn t)
EJust _ e -> EJust ext (dfwdDN e)
EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2)
+ ELInl _ t e -> ELInl ext (dn t) (dfwdDN e)
+ ELInr _ t e -> ELInr ext (dn t) (dfwdDN e)
+ ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c)
EConstArr _ n t x -> scalTyCase t
(emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
(EConstArr ext n t x))
(EConstArr ext n t x)
EBuild _ n a b
| Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
- EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)
ESum1Inner _ e ->
let STArr n (STScal t) = typeOf e
pairty = (STPair (STScal t) (STScal t))
@@ -178,10 +185,12 @@ dfwdDN = \case
ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
+ ERecompute _ e -> dfwdDN e
EError _ t s -> EError ext (dn t) s
EWith{} -> err_accum
EAccum{} -> err_accum
+ EDeepZero{} -> err_monoid
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
index fba92d0..dcacf5f 100644
--- a/src/ForwardAD/DualNumbers/Types.hs
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -12,6 +12,7 @@ type family DN t where
DN TNil = TNil
DN (TPair a b) = TPair (DN a) (DN b)
DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TLEither a b) = TLEither (DN a) (DN b)
DN (TMaybe t) = TMaybe (DN t)
DN (TArr n t) = TArr n (DN t)
DN (TScal t) = DNS t
@@ -31,6 +32,7 @@ dn :: STy t -> STy (DN t)
dn STNil = STNil
dn (STPair a b) = STPair (dn a) (dn b)
dn (STEither a b) = STEither (dn a) (dn b)
+dn (STLEither a b) = STLEither (dn a) (dn b)
dn (STMaybe t) = STMaybe (dn t)
dn (STArr n t) = STArr n (dn t)
dn (STScal t) = case t of