summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:41:43 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:41:43 +0100
commitc49deb3baa4baaca78a9d301d4dd17db84fb6c5b (patch)
treef56144d090db20de6d3624e74d4310aa7f939799
parentcabd95c691e7bf0bf5adb4609e6df2a10b08856c (diff)
test: Little cleanup
-rw-r--r--test/Main.hs9
1 files changed, 6 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 2872029..83eaa83 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -181,7 +181,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
adTestGen :: forall env. KnownEnv env
=> TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
- withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do
+ withCompiled expr $ \getprimalfun ->
+ testProperty name $ property $ do
let env = knownEnv @env
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -199,8 +200,7 @@ adTestGen name expr envGenerator =
convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
let outPrimalI = interpretOpen False input expr
- primalfun <- liftIO $ getprimalfun
- outPrimal <- liftIO $ primalfun input
+ outPrimal <- liftIO $ getprimalfun >>= ($ input)
diff outPrimal closeIsh outPrimalI
let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
@@ -222,6 +222,9 @@ adTestGen name expr envGenerator =
envScalars SNil SNil = []
envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
+withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ())
+
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $