diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 11:34:04 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 11:34:04 +0100 |
commit | 2b1562d33bb9496aa449ef9d52735af0ec61c15c (patch) | |
tree | 77aa4f3195b1493828e3c82c9bae4b419ba27c64 | |
parent | 992249ebf159ba3783a9345430013e52294c26aa (diff) |
Some more primitive operators
-rw-r--r-- | src/AST.hs | 6 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 3 | ||||
-rw-r--r-- | src/AST/Types.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 11 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 22 | ||||
-rw-r--r-- | src/Interpreter.hs | 7 |
6 files changed, 54 insertions, 2 deletions
@@ -169,6 +169,9 @@ data SOp a t where OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right ORound64 :: SOp (TScal TF64) (TScal TI64) OToFl64 :: SOp (TScal TI64) (TScal TF64) + ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) deriving instance Show (SOp a t) opt2 :: SOp a t -> STy t @@ -185,6 +188,9 @@ opt2 = \case OIf -> STEither STNil STNil ORound64 -> STScal STI64 OToFl64 -> STScal STF64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t typeOf :: Expr x env t -> STy t typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 4d9aeec..63742ad 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -275,6 +275,9 @@ operator OOr = (Infix, "||") operator OIf = (Prefix, "ifB") operator ORound64 = (Prefix, "round") operator OToFl64 = (Prefix, "toFl64") +operator ORecip{} = (Prefix, "recip") +operator OExp{} = (Prefix, "exp") +operator OLog{} = (Prefix, "log") ppTy :: Int -> STy t -> String ppTy d ty = ppTys d ty "" diff --git a/src/AST/Types.hs b/src/AST/Types.hs index a3e5080..5688277 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -93,3 +93,10 @@ type family ScalIsNumeric t where ScalIsNumeric TF32 = True ScalIsNumeric TF64 = True ScalIsNumeric TBool = False + +type family ScalIsFloating t where + ScalIsFloating TI32 = False + ScalIsFloating TI64 = False + ScalIsFloating TF32 = True + ScalIsFloating TF64 = True + ScalIsFloating TBool = False diff --git a/src/CHAD.hs b/src/CHAD.hs index a08fe80..fb6f5e3 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -259,6 +259,9 @@ d1op OOr e = EOp ext OOr e d1op OIf e = EOp ext OIf e d1op ORound64 e = EOp ext ORound64 e d1op OToFl64 e = EOp ext OToFl64 e +d1op (ORecip t) e = EOp ext (ORecip t) e +d1op (OExp t) e = EOp ext (OExp t) e +d1op (OLog t) e = EOp ext (OLog t) e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -280,6 +283,9 @@ d2op op = case op of OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext + ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) + OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) + OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) @@ -301,6 +307,11 @@ d2op op = case op of STF64 -> float STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + floatingD2 :: ScalIsFloating a ~ True + => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r + floatingD2 STF32 k = k + floatingD2 STF64 k = k + sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 7d47e6d..9ed04bb 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -41,6 +41,12 @@ scalTyCase STI32 _ k2 = k2 scalTyCase STI64 _ k2 = k2 scalTyCase STBool _ k2 = k2 +floatingDual :: ScalIsFloating t ~ True + => SScalTy t + -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r +floatingDual STF32 k = k +floatingDual STF64 k = k + -- | Argument does not need to be duplicable. dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) dop = \case @@ -68,6 +74,14 @@ dop = \case OIf -> EOp ext OIf ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) + ORecip t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (recip' t x) + (mul t (neg t (recip' t (mul t x x))) dx)) + OExp t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OExp t) x) (EOp ext (OExp t) dx)) + OLog t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OLog t) x) + (mul t (recip' t x) dx)) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) @@ -81,6 +95,10 @@ dop = \case => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) neg t = EOp ext (ONeg t) + recip' :: ScalIsFloating t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + recip' t = EOp ext (ORecip t) + unFloat :: DN a ~ TPair a a => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) -> Ex env (DN a) -> Ex env (DN b) @@ -99,10 +117,10 @@ dop = \case (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) -zeroScalarConst STF32 = EConst ext STF32 0.0 -zeroScalarConst STF64 = EConst ext STF64 0.0 zeroScalarConst STI32 = EConst ext STI32 0 zeroScalarConst STI64 = EConst ext STI64 0 +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 dfwdDN :: Ex env t -> Ex (DNE env) (DN t) dfwdDN = \case diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3c1aad0..576b0d9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -170,6 +170,9 @@ interpretOp op arg = case op of OIf -> if arg then Left () else Right () ORound64 -> round arg OToFl64 -> fromIntegral arg + ORecip st -> floatingIsFractional st $ recip arg + OExp st -> floatingIsFractional st $ exp arg + OLog st -> floatingIsFractional st $ log arg where styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r styIsEq STI32 = id @@ -523,6 +526,10 @@ numericIsNum STI64 = id numericIsNum STF32 = id numericIsNum STF64 = id +floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True) => r) -> r +floatingIsFractional STF32 = id +floatingIsFractional STF64 = id + unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n unTupRepIdx nil _ SZ _ = nil |