summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 11:34:04 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 11:34:04 +0100
commit2b1562d33bb9496aa449ef9d52735af0ec61c15c (patch)
tree77aa4f3195b1493828e3c82c9bae4b419ba27c64 /src/CHAD.hs
parent992249ebf159ba3783a9345430013e52294c26aa (diff)
Some more primitive operators
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs11
1 files changed, 11 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a08fe80..fb6f5e3 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -259,6 +259,9 @@ d1op OOr e = EOp ext OOr e
d1op OIf e = EOp ext OIf e
d1op ORound64 e = EOp ext ORound64 e
d1op OToFl64 e = EOp ext OToFl64 e
+d1op (ORecip t) e = EOp ext (ORecip t) e
+d1op (OExp t) e = EOp ext (OExp t) e
+d1op (OLog t) e = EOp ext (OLog t) e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -280,6 +283,9 @@ d2op op = case op of
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
+ ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
+ OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
+ OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
@@ -301,6 +307,11 @@ d2op op = case op of
STF64 -> float
STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ floatingD2 :: ScalIsFloating a ~ True
+ => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
+ floatingD2 STF32 k = k
+ floatingD2 STF64 k = k
+
sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)