summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:41 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:41 +0100
commite78a7cb73f33453a97fa12cfa8e5af07d1aa6eba (patch)
treef45f5505082353b559e7315659b1a60396e64458
parenta590b1414baec157da3a1f6c5684b1a3bce8ecaf (diff)
test: Also test pre-simplified term
-rw-r--r--test/Main.hs46
1 files changed, 31 insertions, 15 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 92dc446..83271fc 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -206,17 +206,26 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
adTestGen :: forall env. KnownEnv env
=> TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
- let env = knownEnv @env in
+ let env = knownEnv @env
+ exprS = simplifyFix expr
+ in
withCompiled env expr $ \getprimalfun ->
+ withCompiled env (simplifyFix expr) $ \getprimalSfun ->
testGroup name
[testProperty "compile primal" $ property $ do
primalfun <- liftIO getprimalfun
+ primalSfun <- liftIO getprimalSfun
input <- forAllWith (showEnv env) envGenerator
+
let outPrimalI = interpretOpen False input expr
outPrimalC <- liftIO $ primalfun input
diff outPrimalI (closeIsh' 1e-8) outPrimalC
- ,withCompiled (dne env) (dfwdDN expr) $ \getdnfun ->
+ let outPrimalSI = interpretOpen False input exprS
+ outPrimalSC <- liftIO $ primalSfun input
+ diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
+
+ ,withCompiled (dne env) (dfwdDN exprS) $ \getdnfun ->
testProperty "compile fwdAD" $ property $ do
dnfun <- liftIO getdnfun
input <- forAllWith (showEnv env) envGenerator
@@ -226,45 +235,52 @@ adTestGen name expr envGenerator =
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
- ,withResource (makeFwdADArtifactCompile env expr) (\_ -> pure ()) $ \getfwdartifactC ->
+ ,withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \getfwdartifactC ->
testProperty "chad" $ property $ do
- primalfun <- liftIO getprimalfun
+ primalSfun <- liftIO getprimalSfun
fwdartifactC <- liftIO getfwdartifactC
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
dtermChadS = simplifyFix dtermChad0
- dtermChadS20 = simplifyN 20 dtermChad0
+ let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+ dtermSChadS = simplifyFix dtermSChad0
-- pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20))
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
+ diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
input <- forAllWith (showEnv env) envGenerator
+ outPrimal <- liftIO $ primalSfun input
let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
unpackGrad = unTup vUnpair (d2e env) . Value
- outPrimal <- liftIO $ primalfun input
+ let scFwd = envScalars env $ gradientByForward fwdartifactC input
let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
(outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- gradChad0' = toTanE env input gradChad0
- gradChadS' = toTanE env input gradChadS
- scChad = envScalars env gradChad0'
- scChadS = envScalars env gradChadS'
- gradFwd = gradientByForward fwdartifactC input
- scFwd = envScalars env gradFwd
+ (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
+ (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
+ scChad = envScalars env $ toTanE env input gradChad0
+ scChadS = envScalars env $ toTanE env input gradChadS
+ scSChad = envScalars env $ toTanE env input gradSChad0
+ scSChadS = envScalars env $ toTanE env input gradSChadS
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- diff outChadS closeIsh outChad0
+ diff outChad0 closeIsh outPrimal
diff outChadS closeIsh outPrimal
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
]
where