From bd5d0458017862b984b9caf0975c135d154e8515 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 18 Apr 2025 12:53:43 +0200 Subject: pretty: Print arguments of open expression --- test/Main.hs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index 20b4ef0..3a6bc71 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -57,7 +57,7 @@ simplifyIters iters env | Dict <- envKnown env = gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD simplIters env term input = let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False input dterm + (out, grad) = interpretOpen False env input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. @@ -232,7 +232,7 @@ compileTestGen name expr envGenerator = in withCompiled env expr $ \fun -> testProperty name $ property $ do input <- forAllWith (showEnv env) envGenerator - let resI = interpretOpen False input expr + let resI = interpretOpen False env input expr resC <- liftIO $ fun input let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y diff (TypedValue t resI) cmp (TypedValue t resC) @@ -269,11 +269,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = testProperty "compile primal" $ property $ do input <- forAllWith (showEnv env) envGenerator - let outPrimalI = interpretOpen False input expr + let outPrimalI = interpretOpen False env input expr outPrimalC <- liftIO $ primalfun input diff outPrimalI (closeIsh' 1e-8) outPrimalC - let outPrimalSI = interpretOpen False input exprS + let outPrimalSI = interpretOpen False env input exprS outPrimalSC <- liftIO $ primalSfun input diff outPrimalSI (closeIsh' 1e-8) outPrimalSC @@ -285,7 +285,7 @@ adTestGenFwd env envGenerator exprS = testProperty "compile fwdAD" $ property $ do input <- forAllWith (showEnv env) envGenerator dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input - let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS) + let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS) (outDNC1, outDNC2) <- liftIO $ dnfun dinput diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 @@ -317,10 +317,10 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let scFwd = tanEScalars env $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS scChad = tanEScalars env $ toTanE env input gradChad0 scChadS = tanEScalars env $ toTanE env input gradChadS scSChad = tanEScalars env $ toTanE env input gradSChad0 -- cgit v1.2.3-70-g09d2