diff options
-rw-r--r-- | test/Main.hs | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/test/Main.hs b/test/Main.hs index 117a864..5fa1d46 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,7 +10,6 @@ {-# LANGUAGE UndecidableInstances #-} module Main where -import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Bifunctor @@ -231,7 +230,7 @@ compileTestGen name expr envGenerator = testProperty name $ property $ do input <- forAllWith (showEnv env) envGenerator let resI = interpretOpen False env input expr - resC <- liftIO $ fun input + resC <- evalIO $ fun input let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y diff (TypedValue t resI) cmp (TypedValue t resC) @@ -268,11 +267,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = input <- forAllWith (showEnv env) envGenerator let outPrimalI = interpretOpen False env input expr - outPrimalC <- liftIO $ primalfun input + outPrimalC <- evalIO $ primalfun input diff outPrimalI (closeIsh' 1e-8) outPrimalC let outPrimalSI = interpretOpen False env input exprS - outPrimalSC <- liftIO $ primalSfun input + outPrimalSC <- evalIO $ primalSfun input diff outPrimalSI (closeIsh' 1e-8) outPrimalSC adTestGenFwd :: SList STy env -> Gen (SList Value env) @@ -284,7 +283,7 @@ adTestGenFwd env envGenerator exprS = input <- forAllWith (showEnv env) envGenerator dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS) - (outDNC1, outDNC2) <- liftIO $ dnfun dinput + (outDNC1, outDNC2) <- evalIO $ dnfun dinput diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 @@ -308,7 +307,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) input <- forAllWith (showEnv env) envGenerator - outPrimal <- liftIO $ primalSfun input + outPrimal <- evalIO $ primalSfun input let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) unpackGrad = unTup vUnpair (d2e env) . Value @@ -324,7 +323,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = scSChad = tanEScalars env $ toTanE env input gradSChad0 scSChadS = tanEScalars env $ toTanE env input gradSChadS - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) + (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) |