diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:21:51 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:21:51 +0100 |
commit | 146a846f799f63cd98eee2149c417686adba17a9 (patch) | |
tree | 3d2a5f13ebea86cc2618ea34da483bfdf77c27e8 | |
parent | 050ee6d17819e1353902f6ccb9c83a125638c375 (diff) |
test: Compile final gradient function (WIP)
This doesn't work yet because Compile doesn't yet support EFold1Inner
-rw-r--r-- | test/Main.hs | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/test/Main.hs b/test/Main.hs index 5cc00a1..19bd3e6 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -248,6 +248,7 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) -> TestTree adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> + withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS -> testProperty "chad" $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -268,28 +269,34 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let scFwd = tanEScalars env $ gradientByForward fwdartifactC input - let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 + scChad = tanEScalars env $ toTanE env input gradChad0 + scChadS = tanEScalars env $ toTanE env input gradChadS + scSChad = tanEScalars env $ toTanE env input gradSChad0 scSChadS = tanEScalars env $ toTanE env input gradSChadS + (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) + let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outCompSChadS closeIsh outPrimal + let closeIshList x y = and (zipWith closeIsh x y) + diff scChad closeIshList scFwd + diff scChadS closeIshList scFwd + diff scSChad closeIshList scFwd + diff scSChadS closeIshList scFwd + diff scCompSChadS closeIshList scFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) |