From cabd95c691e7bf0bf5adb4609e6df2a10b08856c Mon Sep 17 00:00:00 2001
From: Tom Smeding <t.j.smeding@uu.nl>
Date: Mon, 3 Mar 2025 17:34:16 +0100
Subject: Run test primals with Compile (not all succeed yet)

---
 test/Main.hs | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

(limited to 'test/Main.hs')

diff --git a/test/Main.hs b/test/Main.hs
index ec23eaf..2872029 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -11,6 +11,7 @@
 module Main where
 
 import Control.Monad.Trans.Class (lift)
+import Control.Monad.IO.Class (liftIO)
 import Control.Monad.Trans.State
 import Data.Bifunctor
 import Data.Int (Int64)
@@ -30,6 +31,7 @@ import AST.UnMonoid
 import CHAD.Top
 import CHAD.Types
 import CHAD.Types.ToTan
+import Compile
 import qualified Example
 import qualified Example.GMM as Example
 import ForwardAD
@@ -178,7 +180,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
 
 adTestGen :: forall env. KnownEnv env
           => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
-adTestGen name expr envGenerator = testProperty name $ property $ do
+adTestGen name expr envGenerator =
+  withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do
   let env = knownEnv @env
 
   annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -195,8 +198,12 @@ adTestGen name expr envGenerator = testProperty name $ property $ do
   let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
       convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
 
-  let outPrimal = interpretOpen False input expr
-      (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
+  let outPrimalI = interpretOpen False input expr
+  primalfun <- liftIO $ getprimalfun
+  outPrimal <- liftIO $ primalfun input
+  diff outPrimal closeIsh outPrimalI
+
+  let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
       (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
       scChad = envScalars env gradChad0
       scChadS = envScalars env gradChadS
-- 
cgit v1.2.3-70-g09d2