diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 13 | 
1 files changed, 6 insertions, 7 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 117a864..5fa1d46 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,7 +10,6 @@  {-# LANGUAGE UndecidableInstances #-}  module Main where -import Control.Monad.IO.Class (liftIO)  import Control.Monad.Trans.Class (lift)  import Control.Monad.Trans.State  import Data.Bifunctor @@ -231,7 +230,7 @@ compileTestGen name expr envGenerator =       testProperty name $ property $ do         input <- forAllWith (showEnv env) envGenerator         let resI = interpretOpen False env input expr -       resC <- liftIO $ fun input +       resC <- evalIO $ fun input         let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y         diff (TypedValue t resI) cmp (TypedValue t resC) @@ -268,11 +267,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =      input <- forAllWith (showEnv env) envGenerator      let outPrimalI = interpretOpen False env input expr -    outPrimalC <- liftIO $ primalfun input +    outPrimalC <- evalIO $ primalfun input      diff outPrimalI (closeIsh' 1e-8) outPrimalC      let outPrimalSI = interpretOpen False env input exprS -    outPrimalSC <- liftIO $ primalSfun input +    outPrimalSC <- evalIO $ primalSfun input      diff outPrimalSI (closeIsh' 1e-8) outPrimalSC  adTestGenFwd :: SList STy env -> Gen (SList Value env) @@ -284,7 +283,7 @@ adTestGenFwd env envGenerator exprS =        input <- forAllWith (showEnv env) envGenerator        dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input        let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS) -      (outDNC1, outDNC2) <- liftIO $ dnfun dinput +      (outDNC1, outDNC2) <- evalIO $ dnfun dinput        diff outDNI1 (closeIsh' 1e-8) outDNC1        diff outDNI2 (closeIsh' 1e-8) outDNC2 @@ -308,7 +307,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =        diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))        input <- forAllWith (showEnv env) envGenerator -      outPrimal <- liftIO $ primalSfun input +      outPrimal <- evalIO $ primalSfun input        let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)            unpackGrad = unTup vUnpair (d2e env) . Value @@ -324,7 +323,7 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =            scSChad  = tanEScalars env $ toTanE env input gradSChad0            scSChadS = tanEScalars env $ toTanE env input gradSChadS -      (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) +      (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)        let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS        -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) | 
