diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:21:51 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:21:51 +0100 | 
| commit | 146a846f799f63cd98eee2149c417686adba17a9 (patch) | |
| tree | 3d2a5f13ebea86cc2618ea34da483bfdf77c27e8 | |
| parent | 050ee6d17819e1353902f6ccb9c83a125638c375 (diff) | |
test: Compile final gradient function (WIP)
This doesn't work yet because Compile doesn't yet support EFold1Inner
| -rw-r--r-- | test/Main.hs | 33 | 
1 files changed, 20 insertions, 13 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 5cc00a1..19bd3e6 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -248,6 +248,7 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)                -> TestTree  adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =    withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> +  withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS ->      testProperty "chad" $ property $ do        annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -268,28 +269,34 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =        let scFwd = tanEScalars env $ gradientByForward fwdartifactC input -      let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 -          (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS +      let (outChad0 , gradChad0)  = second unpackGrad $ interpretOpen False input dtermChad0 +          (outChadS , gradChadS)  = second unpackGrad $ interpretOpen False input dtermChadS            (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0            (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS -          scChad = tanEScalars env $ toTanE env input gradChad0 -          scChadS = tanEScalars env $ toTanE env input gradChadS -          scSChad = tanEScalars env $ toTanE env input gradSChad0 +          scChad   = tanEScalars env $ toTanE env input gradChad0 +          scChadS  = tanEScalars env $ toTanE env input gradChadS +          scSChad  = tanEScalars env $ toTanE env input gradSChad0            scSChadS = tanEScalars env $ toTanE env input gradSChadS +      (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) +      let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS +        -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))        -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))        -- annotate (ppExpr knownEnv expr)        -- annotate (ppExpr env dtermChad0)        -- annotate (ppExpr env dtermChadS) -      diff outChad0 closeIsh outPrimal -      diff outChadS closeIsh outPrimal -      diff outSChad0 closeIsh outPrimal -      diff outSChadS closeIsh outPrimal -      diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd -      diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd -      diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd -      diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd +      diff outChad0      closeIsh outPrimal +      diff outChadS      closeIsh outPrimal +      diff outSChad0     closeIsh outPrimal +      diff outSChadS     closeIsh outPrimal +      diff outCompSChadS closeIsh outPrimal +      let closeIshList x y = and (zipWith closeIsh x y) +      diff scChad       closeIshList scFwd +      diff scChadS      closeIshList scFwd +      diff scSChad      closeIshList scFwd +      diff scSChadS     closeIshList scFwd +      diff scCompSChadS closeIshList scFwd  withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree  withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) | 
