diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-26 21:29:33 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-26 21:29:33 +0200 |
commit | 633302b54e90d4b34f4a717327167c196171a250 (patch) | |
tree | 8a737d7415867d72ef8cfcdc9e47f1d1f6da5433 /test | |
parent | 57d826d7e1fae089a3ec61da60d6f1ca1a4e49d2 (diff) |
Also test primal results
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs index c746807..11d15b4 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -64,14 +64,14 @@ diffCHAD = \simplIters env term -> makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env)) +gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD = \simplIters env term input -> case (mapMergeNoAccum env, mapMergeOnlyMerge env) of (Refl, Refl) -> let dterm = diffCHAD simplIters env term input1 = toPrimalE env input - (_out, grad) = interpretOpen False input1 dterm - in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad)) + (out, grad) = interpretOpen False input1 dterm + in (ppExpr (primalEnv env) dterm, (out, unTup vUnpair (d2e env) (Value grad))) where toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') toPrimalE SNil SNil = SNil @@ -88,8 +88,8 @@ gradientByCHAD = \simplIters env term input -> STAccum{} -> error "Accumulators not allowed in input program" -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env)) -gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input +gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) +gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input where toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) toTanE SNil SNil SNil = SNil @@ -210,9 +210,10 @@ adTestGen :: forall env. KnownEnv env adTestGen expr envGenerator = property $ do let env = knownEnv @env input <- forAllWith (showEnv env) envGenerator - let gradFwd = gradientByForward knownEnv expr input - (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input - (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input + let outPrimal = interpretOpen False input expr + gradFwd = gradientByForward knownEnv expr input + (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' 0 knownEnv expr input + (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' 20 knownEnv expr input scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S @@ -220,6 +221,8 @@ adTestGen expr envGenerator = property $ do annotate (ppExpr knownEnv expr) annotate ppdterm annotate ppdterm_S + diff outChad closeIsh outChad_S + diff outPrimal closeIsh outChad_S diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S where |