summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-17 23:21:51 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-17 23:21:51 +0100
commit146a846f799f63cd98eee2149c417686adba17a9 (patch)
tree3d2a5f13ebea86cc2618ea34da483bfdf77c27e8
parent050ee6d17819e1353902f6ccb9c83a125638c375 (diff)
test: Compile final gradient function (WIP)
This doesn't work yet because Compile doesn't yet support EFold1Inner
-rw-r--r--test/Main.hs33
1 files changed, 20 insertions, 13 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 5cc00a1..19bd3e6 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -248,6 +248,7 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
-> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
+ withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS ->
testProperty "chad" $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -268,28 +269,34 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
- let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
(outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
(outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
- scChad = tanEScalars env $ toTanE env input gradChad0
- scChadS = tanEScalars env $ toTanE env input gradChadS
- scSChad = tanEScalars env $ toTanE env input gradSChad0
+ scChad = tanEScalars env $ toTanE env input gradChad0
+ scChadS = tanEScalars env $ toTanE env input gradChadS
+ scSChad = tanEScalars env $ toTanE env input gradSChad0
scSChadS = tanEScalars env $ toTanE env input gradSChadS
+ (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input)
+ let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS
+
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outCompSChadS closeIsh outPrimal
+ let closeIshList x y = and (zipWith closeIsh x y)
+ diff scChad closeIshList scFwd
+ diff scChadS closeIshList scFwd
+ diff scSChad closeIshList scFwd
+ diff scSChadS closeIshList scFwd
+ diff scCompSChadS closeIshList scFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())