diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-03 17:41:43 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-03 17:41:43 +0100 |
commit | c49deb3baa4baaca78a9d301d4dd17db84fb6c5b (patch) | |
tree | f56144d090db20de6d3624e74d4310aa7f939799 | |
parent | cabd95c691e7bf0bf5adb4609e6df2a10b08856c (diff) |
test: Little cleanup
-rw-r--r-- | test/Main.hs | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs index 2872029..83eaa83 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -181,7 +181,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) adTestGen :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree adTestGen name expr envGenerator = - withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do + withCompiled expr $ \getprimalfun -> + testProperty name $ property $ do let env = knownEnv @env annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -199,8 +200,7 @@ adTestGen name expr envGenerator = convGrad = toTanE env input . unTup vUnpair (d2e env) . Value let outPrimalI = interpretOpen False input expr - primalfun <- liftIO $ getprimalfun - outPrimal <- liftIO $ primalfun input + outPrimal <- liftIO $ getprimalfun >>= ($ input) diff outPrimal closeIsh outPrimalI let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 @@ -222,6 +222,9 @@ adTestGen name expr envGenerator = envScalars SNil SNil = [] envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs +withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree +withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ()) + term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ |