diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:41 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:41 +0100 |
commit | e78a7cb73f33453a97fa12cfa8e5af07d1aa6eba (patch) | |
tree | f45f5505082353b559e7315659b1a60396e64458 | |
parent | a590b1414baec157da3a1f6c5684b1a3bce8ecaf (diff) |
test: Also test pre-simplified term
-rw-r--r-- | test/Main.hs | 46 |
1 files changed, 31 insertions, 15 deletions
diff --git a/test/Main.hs b/test/Main.hs index 92dc446..83271fc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -206,17 +206,26 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) adTestGen :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree adTestGen name expr envGenerator = - let env = knownEnv @env in + let env = knownEnv @env + exprS = simplifyFix expr + in withCompiled env expr $ \getprimalfun -> + withCompiled env (simplifyFix expr) $ \getprimalSfun -> testGroup name [testProperty "compile primal" $ property $ do primalfun <- liftIO getprimalfun + primalSfun <- liftIO getprimalSfun input <- forAllWith (showEnv env) envGenerator + let outPrimalI = interpretOpen False input expr outPrimalC <- liftIO $ primalfun input diff outPrimalI (closeIsh' 1e-8) outPrimalC - ,withCompiled (dne env) (dfwdDN expr) $ \getdnfun -> + let outPrimalSI = interpretOpen False input exprS + outPrimalSC <- liftIO $ primalSfun input + diff outPrimalSI (closeIsh' 1e-8) outPrimalSC + + ,withCompiled (dne env) (dfwdDN exprS) $ \getdnfun -> testProperty "compile fwdAD" $ property $ do dnfun <- liftIO getdnfun input <- forAllWith (showEnv env) envGenerator @@ -226,45 +235,52 @@ adTestGen name expr envGenerator = diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 - ,withResource (makeFwdADArtifactCompile env expr) (\_ -> pure ()) $ \getfwdartifactC -> + ,withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \getfwdartifactC -> testProperty "chad" $ property $ do - primalfun <- liftIO getprimalfun + primalSfun <- liftIO getprimalSfun fwdartifactC <- liftIO getfwdartifactC annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr dtermChadS = simplifyFix dtermChad0 - dtermChadS20 = simplifyN 20 dtermChad0 + let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChadS = simplifyFix dtermSChad0 -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) input <- forAllWith (showEnv env) envGenerator + outPrimal <- liftIO $ primalSfun input let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) unpackGrad = unTup vUnpair (d2e env) . Value - outPrimal <- liftIO $ primalfun input + let scFwd = envScalars env $ gradientByForward fwdartifactC input let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS - gradChad0' = toTanE env input gradChad0 - gradChadS' = toTanE env input gradChadS - scChad = envScalars env gradChad0' - scChadS = envScalars env gradChadS' - gradFwd = gradientByForward fwdartifactC input - scFwd = envScalars env gradFwd + (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 + (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS + scChad = envScalars env $ toTanE env input gradChad0 + scChadS = envScalars env $ toTanE env input gradChadS + scSChad = envScalars env $ toTanE env input gradSChad0 + scSChadS = envScalars env $ toTanE env input gradSChadS -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - diff outChadS closeIsh outChad0 + diff outChad0 closeIsh outPrimal diff outChadS closeIsh outPrimal - diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd + diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd + diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd ] where |