summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs18
1 files changed, 9 insertions, 9 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 20b4ef0..3a6bc71 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -57,7 +57,7 @@ simplifyIters iters env | Dict <- envKnown env =
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False input dterm
+ (out, grad) = interpretOpen False env input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
@@ -232,7 +232,7 @@ compileTestGen name expr envGenerator =
in withCompiled env expr $ \fun ->
testProperty name $ property $ do
input <- forAllWith (showEnv env) envGenerator
- let resI = interpretOpen False input expr
+ let resI = interpretOpen False env input expr
resC <- liftIO $ fun input
let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
diff (TypedValue t resI) cmp (TypedValue t resC)
@@ -269,11 +269,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
testProperty "compile primal" $ property $ do
input <- forAllWith (showEnv env) envGenerator
- let outPrimalI = interpretOpen False input expr
+ let outPrimalI = interpretOpen False env input expr
outPrimalC <- liftIO $ primalfun input
diff outPrimalI (closeIsh' 1e-8) outPrimalC
- let outPrimalSI = interpretOpen False input exprS
+ let outPrimalSI = interpretOpen False env input exprS
outPrimalSC <- liftIO $ primalSfun input
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
@@ -285,7 +285,7 @@ adTestGenFwd env envGenerator exprS =
testProperty "compile fwdAD" $ property $ do
input <- forAllWith (showEnv env) envGenerator
dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
- let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS)
+ let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS)
(outDNC1, outDNC2) <- liftIO $ dnfun dinput
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
@@ -317,10 +317,10 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
scChad = tanEScalars env $ toTanE env input gradChad0
scChadS = tanEScalars env $ toTanE env input gradChadS
scSChad = tanEScalars env $ toTanE env input gradSChad0