diff options
| -rw-r--r-- | src/Compile.hs | 2 | ||||
| -rw-r--r-- | src/Compile/Exec.hs | 14 | ||||
| -rw-r--r-- | test/Main.hs | 13 | 
3 files changed, 22 insertions, 7 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 4c07f3a..d31c531 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -48,7 +48,7 @@ compile :: SList STy env -> Ex env t          -> IO (SList Value env -> IO (Rep t))  compile = \env expr -> do    let source = compileToString env expr -  hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" +  -- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"    lib <- buildKernel source ["kernel"]    let arg_metrics = reverse (unSList metricsSTy env) diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index 487ed8a..83ce4ff 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -6,6 +6,7 @@ module Compile.Exec (    callKernelFun,  ) where +import Control.Monad (when)  import Data.IORef  import qualified Data.Map.Strict as Map  import Data.Map.Strict (Map) @@ -20,6 +21,9 @@ import System.Posix.Temp (mkdtemp)  import System.Process (readProcess) +debug :: Bool +debug = False +  -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)  data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ())))) @@ -29,11 +33,15 @@ buildKernel csource funnames = do    path <- mkdtemp template    let outso = path ++ "/out.so" -  let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-", "-Wall", "-Wextra", "-Wno-unused-parameter"] +  let args = ["-O3", "-march=native" +             ,"-shared", "-fPIC" +             ,"-std=c99", "-x", "c" +             ,"-o", outso, "-" +             ,"-Wall", "-Wextra", "-Wno-unused-variable", "-Wno-unused-parameter"]    _ <- readProcess "gcc" args csource    numLoaded <- atomicModifyIORef' numLoadedCounter (\n -> (n+1, n+1)) -  hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)" +  when debug $ hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)"    dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL]    removeDirectoryRecursive path  -- we keep a reference anyway because we have the file open now @@ -41,7 +49,7 @@ buildKernel csource funnames = do    ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames]    ref <- newIORef ptrs    _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) -                           hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" +                           when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"                             dlclose dl)    return (KernelLib ref) diff --git a/test/Main.hs b/test/Main.hs index ec23eaf..2872029 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -11,6 +11,7 @@  module Main where  import Control.Monad.Trans.Class (lift) +import Control.Monad.IO.Class (liftIO)  import Control.Monad.Trans.State  import Data.Bifunctor  import Data.Int (Int64) @@ -30,6 +31,7 @@ import AST.UnMonoid  import CHAD.Top  import CHAD.Types  import CHAD.Types.ToTan +import Compile  import qualified Example  import qualified Example.GMM as Example  import ForwardAD @@ -178,7 +180,8 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)  adTestGen :: forall env. KnownEnv env            => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree -adTestGen name expr envGenerator = testProperty name $ property $ do +adTestGen name expr envGenerator = +  withResource (compile knownEnv expr) (\_ -> pure ()) $ \getprimalfun -> testProperty name $ property $ do    let env = knownEnv @env    annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -195,8 +198,12 @@ adTestGen name expr envGenerator = testProperty name $ property $ do    let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)        convGrad = toTanE env input . unTup vUnpair (d2e env) . Value -  let outPrimal = interpretOpen False input expr -      (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 +  let outPrimalI = interpretOpen False input expr +  primalfun <- liftIO $ getprimalfun +  outPrimal <- liftIO $ primalfun input +  diff outPrimal closeIsh outPrimalI + +  let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0        (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS        scChad = envScalars env gradChad0        scChadS = envScalars env gradChadS | 
