summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 11:34:04 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 11:34:04 +0100
commit2b1562d33bb9496aa449ef9d52735af0ec61c15c (patch)
tree77aa4f3195b1493828e3c82c9bae4b419ba27c64
parent992249ebf159ba3783a9345430013e52294c26aa (diff)
Some more primitive operators
-rw-r--r--src/AST.hs6
-rw-r--r--src/AST/Pretty.hs3
-rw-r--r--src/AST/Types.hs7
-rw-r--r--src/CHAD.hs11
-rw-r--r--src/ForwardAD/DualNumbers.hs22
-rw-r--r--src/Interpreter.hs7
6 files changed, 54 insertions, 2 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 28c5b37..e3da634 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -169,6 +169,9 @@ data SOp a t where
OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
ORound64 :: SOp (TScal TF64) (TScal TI64)
OToFl64 :: SOp (TScal TI64) (TScal TF64)
+ ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
deriving instance Show (SOp a t)
opt2 :: SOp a t -> STy t
@@ -185,6 +188,9 @@ opt2 = \case
OIf -> STEither STNil STNil
ORound64 -> STScal STI64
OToFl64 -> STScal STF64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 4d9aeec..63742ad 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -275,6 +275,9 @@ operator OOr = (Infix, "||")
operator OIf = (Prefix, "ifB")
operator ORound64 = (Prefix, "round")
operator OToFl64 = (Prefix, "toFl64")
+operator ORecip{} = (Prefix, "recip")
+operator OExp{} = (Prefix, "exp")
+operator OLog{} = (Prefix, "log")
ppTy :: Int -> STy t -> String
ppTy d ty = ppTys d ty ""
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index a3e5080..5688277 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -93,3 +93,10 @@ type family ScalIsNumeric t where
ScalIsNumeric TF32 = True
ScalIsNumeric TF64 = True
ScalIsNumeric TBool = False
+
+type family ScalIsFloating t where
+ ScalIsFloating TI32 = False
+ ScalIsFloating TI64 = False
+ ScalIsFloating TF32 = True
+ ScalIsFloating TF64 = True
+ ScalIsFloating TBool = False
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a08fe80..fb6f5e3 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -259,6 +259,9 @@ d1op OOr e = EOp ext OOr e
d1op OIf e = EOp ext OIf e
d1op ORound64 e = EOp ext ORound64 e
d1op OToFl64 e = EOp ext OToFl64 e
+d1op (ORecip t) e = EOp ext (ORecip t) e
+d1op (OExp t) e = EOp ext (OExp t) e
+d1op (OLog t) e = EOp ext (OLog t) e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -280,6 +283,9 @@ d2op op = case op of
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
+ ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
+ OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
+ OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
@@ -301,6 +307,11 @@ d2op op = case op of
STF64 -> float
STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ floatingD2 :: ScalIsFloating a ~ True
+ => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
+ floatingD2 STF32 k = k
+ floatingD2 STF64 k = k
+
sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 7d47e6d..9ed04bb 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -41,6 +41,12 @@ scalTyCase STI32 _ k2 = k2
scalTyCase STI64 _ k2 = k2
scalTyCase STBool _ k2 = k2
+floatingDual :: ScalIsFloating t ~ True
+ => SScalTy t
+ -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
+floatingDual STF32 k = k
+floatingDual STF64 k = k
+
-- | Argument does not need to be duplicable.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
@@ -68,6 +74,14 @@ dop = \case
OIf -> EOp ext OIf
ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
+ ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (recip' t x)
+ (mul t (neg t (recip' t (mul t x x))) dx))
+ OExp t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OExp t) x) (EOp ext (OExp t) dx))
+ OLog t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OLog t) x)
+ (mul t (recip' t x) dx))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
@@ -81,6 +95,10 @@ dop = \case
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
neg t = EOp ext (ONeg t)
+ recip' :: ScalIsFloating t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ recip' t = EOp ext (ORecip t)
+
unFloat :: DN a ~ TPair a a
=> (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
-> Ex env (DN a) -> Ex env (DN b)
@@ -99,10 +117,10 @@ dop = \case
(EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
-zeroScalarConst STF32 = EConst ext STF32 0.0
-zeroScalarConst STF64 = EConst ext STF64 0.0
zeroScalarConst STI32 = EConst ext STI32 0
zeroScalarConst STI64 = EConst ext STI64 0
+zeroScalarConst STF32 = EConst ext STF32 0.0
+zeroScalarConst STF64 = EConst ext STF64 0.0
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3c1aad0..576b0d9 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -170,6 +170,9 @@ interpretOp op arg = case op of
OIf -> if arg then Left () else Right ()
ORound64 -> round arg
OToFl64 -> fromIntegral arg
+ ORecip st -> floatingIsFractional st $ recip arg
+ OExp st -> floatingIsFractional st $ exp arg
+ OLog st -> floatingIsFractional st $ log arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
@@ -523,6 +526,10 @@ numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id
+floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True) => r) -> r
+floatingIsFractional STF32 = id
+floatingIsFractional STF64 = id
+
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
-> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil