diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 |
commit | e7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (patch) | |
tree | 4dc880e6956b42f0920382d772b49adc2a4ce556 /src | |
parent | 246439502b78c4a8fcc27ab3296c67471a2b239d (diff) |
Tests
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 4 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 4 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 13 | ||||
-rw-r--r-- | src/Language.hs | 3 |
7 files changed, 28 insertions, 2 deletions
@@ -150,6 +150,8 @@ data SOp a t where OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) ONot :: SOp (TScal TBool) (TScal TBool) OIf :: SOp (TScal TBool) (TEither TNil TNil) + ORound64 :: SOp (TScal TF64) (TScal TI64) + OToFl64 :: SOp (TScal TI64) (TScal TF64) deriving instance Show (SOp a t) opt2 :: SOp a t -> STy t @@ -162,6 +164,8 @@ opt2 = \case OEq _ -> STScal STBool ONot -> STScal STBool OIf -> STEither STNil STNil + ORound64 -> STScal STI64 + OToFl64 -> STScal STF64 typeOf :: Expr x env t -> STy t typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 7f60db1..8f1fe67 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -243,3 +243,5 @@ operator OLe{} = (Infix, "<=") operator OEq{} = (Infix, "==") operator ONot = (Prefix, "not") operator OIf = (Prefix, "ifB") +operator ORound64 = (Prefix, "round") +operator OToFl64 = (Prefix, "toFl64") diff --git a/src/CHAD.hs b/src/CHAD.hs index 55d94b1..d05e77f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -546,6 +546,8 @@ d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -563,6 +565,8 @@ d2op op = case op of OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) ONot -> Linear $ \_ -> ENil ext OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + OToFl64 -> Linear $ \_ -> ENil ext where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index f9239e9..3e45ce7 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -100,6 +100,8 @@ dop = \case (EOp ext (OEq t)) ONot -> EOp ext ONot OIf -> EOp ext OIf + ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) + OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 4d1358f..8ce1b0e 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -125,6 +125,8 @@ interpretOp op arg = case op of OEq st -> numericIsNum st $ uncurry (==) arg ONot -> not arg OIf -> if arg then Left () else Right () + ORound64 -> round arg + OToFl64 -> fromIntegral arg zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index adb4eba..ed307c0 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -39,5 +39,16 @@ type family RepAcDense t where -- RepAcDense (TScal sty) = ScalRep sty -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") -newtype Value t = Value (Rep t) +newtype Value t = Value { unValue :: Rep t } +liftV :: (Rep a -> Rep b) -> Value a -> Value b +liftV f (Value x) = Value (f x) + +liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c +liftV2 f (Value x) (Value y) = Value (f x y) + +vPair :: Value a -> Value b -> Value (TPair a b) +vPair = liftV2 (,) + +vUnpair :: Value (TPair a b) -> (Value a, Value b) +vUnpair (Value (x, y)) = (Value x, Value y) diff --git a/src/Language.hs b/src/Language.hs index cdc6d6b..80de713 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeOperators #-} module Language ( @@ -22,7 +23,7 @@ infixr 0 :-> body :: NExpr env t -> NFun env env t body = NBody -lambda :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t lambda = NLam |