summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs6
1 files changed, 3 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs
index b3a0795..9ff82a1 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -44,8 +44,8 @@ import Simplify
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
-simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t
-simplifyWith iters env | Dict <- envKnown env =
+simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
+simplifyIters iters env | Dict <- envKnown env =
case iters of
SimplIters n -> simplifyN n
SimplFix -> simplifyFix
@@ -53,7 +53,7 @@ simplifyWith iters env | Dict <- envKnown env =
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
- let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
+ let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
(out, grad) = interpretOpen False input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))