summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:34:16 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:34:16 +0100
commitcabd95c691e7bf0bf5adb4609e6df2a10b08856c (patch)
tree0a9216f89afd28614940dae940a86334b41e5572 /test
parente34869318cd37fa73c12291141a5fea29248aede (diff)
Run test primals with Compile (not all succeed yet)
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs13
1 files changed, 10 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs
index ec23eaf..2872029 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -11,6 +11,7 @@
module Main where
import Control.Monad.Trans.Class (lift)
+import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
@@ -30,6 +31,7 @@ import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
+import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
@@ -178,7 +180,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
adTestGen :: forall env. KnownEnv env
=> TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
-adTestGen name expr envGenerator = testProperty name $ property $ do
+adTestGen name expr envGenerator =
+ withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do
let env = knownEnv @env
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -195,8 +198,12 @@ adTestGen name expr envGenerator = testProperty name $ property $ do
let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
- let outPrimal = interpretOpen False input expr
- (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
+ let outPrimalI = interpretOpen False input expr
+ primalfun <- liftIO $ getprimalfun
+ outPrimal <- liftIO $ primalfun input
+ diff outPrimal closeIsh outPrimalI
+
+ let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
(outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
scChad = envScalars env gradChad0
scChadS = envScalars env gradChadS