From c50eeccd45b8f8fce50bc7f3eeffa9f4ab8a77a4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 11 Oct 2025 20:41:07 +0200 Subject: Test with pruneExpr --- test/Main.hs | 76 +++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 34 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index a1fbf83..392855b 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -25,6 +25,7 @@ import Test.Framework import Array import AST hiding ((.>)) +import AST.Count (pruneExpr) import AST.Pretty import AST.UnMonoid import CHAD.Top @@ -349,17 +350,20 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = pruneExpr env dtermSChadSUS in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env dtermSChadSUS $ \dcompSChadSUS -> + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) - diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0))) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -369,21 +373,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS - (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS - tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS - - (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input) - let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP + + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) @@ -391,21 +397,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outChadSUS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outSChadSUS closeIsh outPrimal - diff outCompSChadSUS closeIsh outPrimal + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansChadSUS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansSChadSUS closeIshE' tansFwd - diff tansCompSChadSUS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -- cgit v1.2.3-70-g09d2