diff options
Diffstat (limited to 'src/ForwardAD')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 11 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 2 |
2 files changed, 12 insertions, 1 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index aa35a5b..3ab08af 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -86,6 +86,9 @@ dop = \case OIDiv t -> scalTyCase t (case t of {}) (EOp ext (OIDiv t)) + OMod t -> scalTyCase t + (case t of {}) + (EOp ext (OMod t)) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) @@ -140,13 +143,17 @@ dfwdDN = \case ENothing _ t -> ENothing ext (dn t) EJust _ e -> EJust ext (dfwdDN e) EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) EConstArr _ n t x -> scalTyCase t (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) (EConstArr ext n t x)) (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) - EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e pairty = (STPair (STScal t) (STScal t)) @@ -178,10 +185,12 @@ dfwdDN = \case ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e EError _ t s -> EError ext (dn t) s EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index fba92d0..dcacf5f 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -12,6 +12,7 @@ type family DN t where DN TNil = TNil DN (TPair a b) = TPair (DN a) (DN b) DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t @@ -31,6 +32,7 @@ dn :: STy t -> STy (DN t) dn STNil = STNil dn (STPair a b) = STPair (dn a) (dn b) dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) dn (STMaybe t) = STMaybe (dn t) dn (STArr n t) = STArr n (dn t) dn (STScal t) = case t of |