summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs13
1 files changed, 6 insertions, 7 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 117a864..5fa1d46 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -10,7 +10,6 @@
{-# LANGUAGE UndecidableInstances #-}
module Main where
-import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
@@ -231,7 +230,7 @@ compileTestGen name expr envGenerator =
testProperty name $ property $ do
input <- forAllWith (showEnv env) envGenerator
let resI = interpretOpen False env input expr
- resC <- liftIO $ fun input
+ resC <- evalIO $ fun input
let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
diff (TypedValue t resI) cmp (TypedValue t resC)
@@ -268,11 +267,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
input <- forAllWith (showEnv env) envGenerator
let outPrimalI = interpretOpen False env input expr
- outPrimalC <- liftIO $ primalfun input
+ outPrimalC <- evalIO $ primalfun input
diff outPrimalI (closeIsh' 1e-8) outPrimalC
let outPrimalSI = interpretOpen False env input exprS
- outPrimalSC <- liftIO $ primalSfun input
+ outPrimalSC <- evalIO $ primalSfun input
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
adTestGenFwd :: SList STy env -> Gen (SList Value env)
@@ -284,7 +283,7 @@ adTestGenFwd env envGenerator exprS =
input <- forAllWith (showEnv env) envGenerator
dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS)
- (outDNC1, outDNC2) <- liftIO $ dnfun dinput
+ (outDNC1, outDNC2) <- evalIO $ dnfun dinput
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
@@ -308,7 +307,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
input <- forAllWith (showEnv env) envGenerator
- outPrimal <- liftIO $ primalSfun input
+ outPrimal <- evalIO $ primalSfun input
let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
unpackGrad = unTup vUnpair (d2e env) . Value
@@ -324,7 +323,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
scSChad = tanEScalars env $ toTanE env input gradSChad0
scSChadS = tanEScalars env $ toTanE env input gradSChadS
- (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input)
+ (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))