summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-27 00:01:15 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-27 00:01:15 +0100
commit75141f1c1f97fef563df2be6e512e568f922cb45 (patch)
treee621c50b2b43fb4050ef074365ef29016869dd35
parent6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 (diff)
test: type R = TScal TF64
-rw-r--r--test/Main.hs41
1 files changed, 22 insertions, 19 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 014ad43..933629c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -41,6 +41,9 @@ import Language
import Simplify
+type R = TScal TF64
+
+
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -51,19 +54,19 @@ simplifyIters iters env | Dict <- envKnown env =
SimplFix -> simplifyFix
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
+gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
(out, grad) = interpretOpen False input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
+gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
second (second (toTanE env input)) $
gradientByCHAD simplIters env term input
-gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
@@ -227,20 +230,20 @@ compileTest name expr =
let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
diff (TypedValue t resI) cmp (TypedValue t resC)
-adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
+adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name = adTestCon name (const True)
-adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
+adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term =
let env = knownEnv
in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
adTestTp :: forall env. KnownEnv env
- => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
+ => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)
adTestGen :: forall env. KnownEnv env
- => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
+ => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
let env = knownEnv @env
exprS = simplifyFix expr
@@ -252,7 +255,7 @@ adTestGen name expr envGenerator =
,adTestGenChad env envGenerator expr exprS primalSfun]
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
-> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
@@ -268,7 +271,7 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
adTestGenFwd :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64)
+ -> Ex env R
-> TestTree
adTestGenFwd env envGenerator exprS =
withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
@@ -281,7 +284,7 @@ adTestGenFwd env envGenerator exprS =
diff outDNI2 (closeIsh' 1e-8) outDNC2
adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double)
-> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
@@ -341,18 +344,18 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
-term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_build1_sum :: Ex '[TArr N1 R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
-term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
+term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
fst_ #q * #x + snd_ #q * fst_ #p
-term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_sparse :: Ex '[TArr N1 R] R
term_sparse = fromNamed $ lambda #inp $ body $
let_ #n (snd_ (shape #inp)) $
let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
@@ -361,7 +364,7 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
-term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_regression_simpl1 :: Ex '[TArr N1 R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
let_ #j (snd_ #idx) $
@@ -369,7 +372,7 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $
(#q ! pair nil 0)
(if_ (#j .== #j) 1.0 2.0)
-term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
+term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
idx0 $ sum1i $
let_ #hei (snd_ (fst_ (shape #mat))) $
@@ -414,7 +417,7 @@ tests_AD = testGroup "AD"
,adTest "pairs" term_pairs
- ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
+ ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
@@ -428,14 +431,14 @@ tests_AD = testGroup "AD"
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TArr N2 R) #x $ body $
idx0 $ sum1i $ maximum1i #x
,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TArr N2 R) #x $ body $
idx0 $ sum1i $ minimum1i #x
- ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
+ ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42