summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-21 23:20:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-21 23:20:57 +0200
commite7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (patch)
tree4dc880e6956b42f0920382d772b49adc2a4ce556 /src
parent246439502b78c4a8fcc27ab3296c67471a2b239d (diff)
Tests
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs4
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/CHAD.hs4
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Interpreter/Rep.hs13
-rw-r--r--src/Language.hs3
7 files changed, 28 insertions, 2 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 5dab62f..94c8537 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -150,6 +150,8 @@ data SOp a t where
OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
OIf :: SOp (TScal TBool) (TEither TNil TNil)
+ ORound64 :: SOp (TScal TF64) (TScal TI64)
+ OToFl64 :: SOp (TScal TI64) (TScal TF64)
deriving instance Show (SOp a t)
opt2 :: SOp a t -> STy t
@@ -162,6 +164,8 @@ opt2 = \case
OEq _ -> STScal STBool
ONot -> STScal STBool
OIf -> STEither STNil STNil
+ ORound64 -> STScal STI64
+ OToFl64 -> STScal STF64
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 7f60db1..8f1fe67 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -243,3 +243,5 @@ operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
operator OIf = (Prefix, "ifB")
+operator ORound64 = (Prefix, "round")
+operator OToFl64 = (Prefix, "toFl64")
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 55d94b1..d05e77f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -546,6 +546,8 @@ d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -563,6 +565,8 @@ d2op op = case op of
OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
ONot -> Linear $ \_ -> ENil ext
OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ OToFl64 -> Linear $ \_ -> ENil ext
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index f9239e9..3e45ce7 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -100,6 +100,8 @@ dop = \case
(EOp ext (OEq t))
ONot -> EOp ext ONot
OIf -> EOp ext OIf
+ ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
+ OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 4d1358f..8ce1b0e 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -125,6 +125,8 @@ interpretOp op arg = case op of
OEq st -> numericIsNum st $ uncurry (==) arg
ONot -> not arg
OIf -> if arg then Left () else Right ()
+ ORound64 -> round arg
+ OToFl64 -> fromIntegral arg
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index adb4eba..ed307c0 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -39,5 +39,16 @@ type family RepAcDense t where
-- RepAcDense (TScal sty) = ScalRep sty
-- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")
-newtype Value t = Value (Rep t)
+newtype Value t = Value { unValue :: Rep t }
+liftV :: (Rep a -> Rep b) -> Value a -> Value b
+liftV f (Value x) = Value (f x)
+
+liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
+liftV2 f (Value x) (Value y) = Value (f x y)
+
+vPair :: Value a -> Value b -> Value (TPair a b)
+vPair = liftV2 (,)
+
+vUnpair :: Value (TPair a b) -> (Value a, Value b)
+vUnpair (Value (x, y)) = (Value x, Value y)
diff --git a/src/Language.hs b/src/Language.hs
index cdc6d6b..80de713 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeOperators #-}
module Language (
@@ -22,7 +23,7 @@ infixr 0 :->
body :: NExpr env t -> NFun env env t
body = NBody
-lambda :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda = NLam