aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs74
1 files changed, 41 insertions, 33 deletions
diff --git a/test/Main.hs b/test/Main.hs
index a1fbf83..392855b 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -25,6 +25,7 @@ import Test.Framework
import Array
import AST hiding ((.>))
+import AST.Count (pruneExpr)
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
@@ -349,17 +350,20 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
dtermSChadS = simplifyFix dtermSChad0
dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
+ dtermSChadSUSP = pruneExpr env dtermSChadSUS
in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- withCompiled env dtermSChadSUS $ \dcompSChadSUS ->
+ withCompiled env dtermSChadSUSP $ \dcompSChadSUSP ->
testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
- diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0)))
+ let dtermChad20 = simplifyN 20 dtermChad0
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20))
+ diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20)))
+ let dtermSChad20 = simplifyN 20 dtermSChad0
+ diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20))
+ diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20)))
input <- forAllWith (showEnv env) envGenerator
outPrimal <- evalIO $ primalSfun input
@@ -369,21 +373,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
- (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
- (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
- (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
- tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
- tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
- tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
- tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
- tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
- tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
+ (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
+ (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
+ (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+ tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP
- (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input)
- let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS
+ (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input)
+ let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP
-- annotate (showEnv (d2e env) gradChad0)
-- annotate (showEnv (d2e env) gradChadS)
@@ -391,21 +397,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outChadSUS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff outSChadSUS closeIsh outPrimal
- diff outCompSChadSUS closeIsh outPrimal
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outChadSUS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outSChadSUS closeIsh outPrimal
+ diff outSChadSUSP closeIsh outPrimal
+ diff outCompSChadSUSP closeIsh outPrimal
let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
- diff tansChad closeIshE' tansFwd
- diff tansChadS closeIshE' tansFwd
- diff tansChadSUS closeIshE' tansFwd
- diff tansSChad closeIshE' tansFwd
- diff tansSChadS closeIshE' tansFwd
- diff tansSChadSUS closeIshE' tansFwd
- diff tansCompSChadSUS closeIshE' tansFwd
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansChadSUS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansSChadSUS closeIshE' tansFwd
+ diff tansSChadSUSP closeIshE' tansFwd
+ diff tansCompSChadSUSP closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())