summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:34:16 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:34:16 +0100
commitcabd95c691e7bf0bf5adb4609e6df2a10b08856c (patch)
tree0a9216f89afd28614940dae940a86334b41e5572
parente34869318cd37fa73c12291141a5fea29248aede (diff)
Run test primals with Compile (not all succeed yet)
-rw-r--r--src/Compile.hs2
-rw-r--r--src/Compile/Exec.hs14
-rw-r--r--test/Main.hs13
3 files changed, 22 insertions, 7 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 4c07f3a..d31c531 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -48,7 +48,7 @@ compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
let source = compileToString env expr
- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
+ -- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
let arg_metrics = reverse (unSList metricsSTy env)
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index 487ed8a..83ce4ff 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -6,6 +6,7 @@ module Compile.Exec (
callKernelFun,
) where
+import Control.Monad (when)
import Data.IORef
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
@@ -20,6 +21,9 @@ import System.Posix.Temp (mkdtemp)
import System.Process (readProcess)
+debug :: Bool
+debug = False
+
-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ()))))
@@ -29,11 +33,15 @@ buildKernel csource funnames = do
path <- mkdtemp template
let outso = path ++ "/out.so"
- let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-", "-Wall", "-Wextra", "-Wno-unused-parameter"]
+ let args = ["-O3", "-march=native"
+ ,"-shared", "-fPIC"
+ ,"-std=c99", "-x", "c"
+ ,"-o", outso, "-"
+ ,"-Wall", "-Wextra", "-Wno-unused-variable", "-Wno-unused-parameter"]
_ <- readProcess "gcc" args csource
numLoaded <- atomicModifyIORef' numLoadedCounter (\n -> (n+1, n+1))
- hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)"
+ when debug $ hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)"
dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL]
removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now
@@ -41,7 +49,7 @@ buildKernel csource funnames = do
ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames]
ref <- newIORef ptrs
_ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))
- hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
+ when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
dlclose dl)
return (KernelLib ref)
diff --git a/test/Main.hs b/test/Main.hs
index ec23eaf..2872029 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -11,6 +11,7 @@
module Main where
import Control.Monad.Trans.Class (lift)
+import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
@@ -30,6 +31,7 @@ import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
+import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
@@ -178,7 +180,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
adTestGen :: forall env. KnownEnv env
=> TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
-adTestGen name expr envGenerator = testProperty name $ property $ do
+adTestGen name expr envGenerator =
+ withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do
let env = knownEnv @env
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -195,8 +198,12 @@ adTestGen name expr envGenerator = testProperty name $ property $ do
let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
- let outPrimal = interpretOpen False input expr
- (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
+ let outPrimalI = interpretOpen False input expr
+ primalfun <- liftIO $ getprimalfun
+ outPrimal <- liftIO $ primalfun input
+ diff outPrimal closeIsh outPrimalI
+
+ let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
(outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
scChad = envScalars env gradChad0
scChadS = envScalars env gradChadS