From 146a846f799f63cd98eee2149c417686adba17a9 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 17 Mar 2025 23:21:51 +0100 Subject: test: Compile final gradient function (WIP) This doesn't work yet because Compile doesn't yet support EFold1Inner --- test/Main.hs | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index 5cc00a1..19bd3e6 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -248,6 +248,7 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) -> TestTree adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> + withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS -> testProperty "chad" $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -268,28 +269,34 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let scFwd = tanEScalars env $ gradientByForward fwdartifactC input - let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 + scChad = tanEScalars env $ toTanE env input gradChad0 + scChadS = tanEScalars env $ toTanE env input gradChadS + scSChad = tanEScalars env $ toTanE env input gradSChad0 scSChadS = tanEScalars env $ toTanE env input gradSChadS + (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) + let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outCompSChadS closeIsh outPrimal + let closeIshList x y = and (zipWith closeIsh x y) + diff scChad closeIshList scFwd + diff scChadS closeIshList scFwd + diff scSChad closeIshList scFwd + diff scSChadS closeIshList scFwd + diff scCompSChadS closeIshList scFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -- cgit v1.2.3-70-g09d2