summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-26 21:29:33 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-26 21:29:33 +0200
commit633302b54e90d4b34f4a717327167c196171a250 (patch)
tree8a737d7415867d72ef8cfcdc9e47f1d1f6da5433 /test
parent57d826d7e1fae089a3ec61da60d6f1ca1a4e49d2 (diff)
Also test primal results
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs19
1 files changed, 11 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs
index c746807..11d15b4 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -64,14 +64,14 @@ diffCHAD = \simplIters env term ->
makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env))
+gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
(Refl, Refl) ->
let dterm = diffCHAD simplIters env term
input1 = toPrimalE env input
- (_out, grad) = interpretOpen False input1 dterm
- in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad))
+ (out, grad) = interpretOpen False input1 dterm
+ in (ppExpr (primalEnv env) dterm, (out, unTup vUnpair (d2e env) (Value grad)))
where
toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
toPrimalE SNil SNil = SNil
@@ -88,8 +88,8 @@ gradientByCHAD = \simplIters env term input ->
STAccum{} -> error "Accumulators not allowed in input program"
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env))
-gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input
+gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
+gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
where
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
toTanE SNil SNil SNil = SNil
@@ -210,9 +210,10 @@ adTestGen :: forall env. KnownEnv env
adTestGen expr envGenerator = property $ do
let env = knownEnv @env
input <- forAllWith (showEnv env) envGenerator
- let gradFwd = gradientByForward knownEnv expr input
- (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input
- (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input
+ let outPrimal = interpretOpen False input expr
+ gradFwd = gradientByForward knownEnv expr input
+ (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' 0 knownEnv expr input
+ (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' 20 knownEnv expr input
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
scCHAD_S = envScalars env gradCHAD_S
@@ -220,6 +221,8 @@ adTestGen expr envGenerator = property $ do
annotate (ppExpr knownEnv expr)
annotate ppdterm
annotate ppdterm_S
+ diff outChad closeIsh outChad_S
+ diff outPrimal closeIsh outChad_S
diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S
where