summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/Main.hs58
1 files changed, 35 insertions, 23 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 79051de..f3aec68 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -345,17 +345,21 @@ adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList
adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr
dtermChadS = simplifyFix dtermChad0
+ dtermChadSUS = simplifyFix $ unMonoid dtermChadS
dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
dtermSChadS = simplifyFix dtermSChad0
+ dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
+ withCompiled env dtermSChadSUS $ \dcompSChadSUS ->
testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
- -- pack Text for less GC pressure (these values are retained for some reason)
+ -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
+ diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0)))
diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
+ diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0)))
input <- forAllWith (showEnv env) envGenerator
outPrimal <- evalIO $ primalSfun input
@@ -365,17 +369,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
- tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
- tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
- tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
-
- (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
- let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
+ (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
+ (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+
+ (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input)
+ let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS
-- annotate (showEnv (d2e env) gradChad0)
-- annotate (showEnv (d2e env) gradChadS)
@@ -383,17 +391,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff outCompSChadS closeIsh outPrimal
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outChadSUS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outSChadSUS closeIsh outPrimal
+ diff outCompSChadSUS closeIsh outPrimal
let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
- diff tansChad closeIshE' tansFwd
- diff tansChadS closeIshE' tansFwd
- diff tansSChad closeIshE' tansFwd
- diff tansSChadS closeIshE' tansFwd
- diff tansCompSChadS closeIshE' tansFwd
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansChadSUS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansSChadSUS closeIshE' tansFwd
+ diff tansCompSChadSUS closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())