summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs3
1 files changed, 2 insertions, 1 deletions
diff --git a/test/Main.hs b/test/Main.hs
index d617228..7cb15d5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -24,6 +24,7 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
+import CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import qualified Example
@@ -41,7 +42,7 @@ data SimplIters = SimplIters Int | SimplFix
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
- let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term
+ let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
dterm | Dict <- envKnown env
= case simplIters of
SimplIters n -> simplifyN n dtermNonSimpl