From cabd95c691e7bf0bf5adb4609e6df2a10b08856c Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Mar 2025 17:34:16 +0100 Subject: Run test primals with Compile (not all succeed yet) --- test/Main.hs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index ec23eaf..2872029 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -11,6 +11,7 @@ module Main where import Control.Monad.Trans.Class (lift) +import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) @@ -30,6 +31,7 @@ import AST.UnMonoid import CHAD.Top import CHAD.Types import CHAD.Types.ToTan +import Compile import qualified Example import qualified Example.GMM as Example import ForwardAD @@ -178,7 +180,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) adTestGen :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree -adTestGen name expr envGenerator = testProperty name $ property $ do +adTestGen name expr envGenerator = + withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do let env = knownEnv @env annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -195,8 +198,12 @@ adTestGen name expr envGenerator = testProperty name $ property $ do let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env) convGrad = toTanE env input . unTup vUnpair (d2e env) . Value - let outPrimal = interpretOpen False input expr - (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 + let outPrimalI = interpretOpen False input expr + primalfun <- liftIO $ getprimalfun + outPrimal <- liftIO $ primalfun input + diff outPrimal closeIsh outPrimalI + + let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS scChad = envScalars env gradChad0 scChadS = envScalars env gradChadS -- cgit v1.2.3-70-g09d2