diff options
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs index 2acc9f8..d586973 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -11,6 +11,7 @@ {-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Bifunctor @@ -352,7 +353,10 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS dtermSChadSUSP = pruneExpr env dtermSChadSUS in - withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> + withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS + when (not (null output)) $ + outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" + return fun) $ \fwdartifactC -> withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -396,7 +400,7 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) + annotate (ppExpr env dtermSChadSUSP) diff outChad0 closeIsh outPrimal diff outChadS closeIsh outPrimal diff outChadSUS closeIsh outPrimal @@ -416,7 +420,11 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do + (fun, output) <- compile env expr + when (not (null output)) $ + outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) gen_gmm = do @@ -562,6 +570,13 @@ tests_Compile = testGroup "Compile" nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree |
