diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 | 
| commit | b8c162ce9cb1faeec621b751fff9aff46e022417 (patch) | |
| tree | 9c31700f34f9a1f1a67e0a73c880938130e87ee6 /test | |
| parent | bb84f6930702a02ba982795e2bb95a64d61f672b (diff) | |
Configuration for CHAD
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 3 | 
1 files changed, 2 insertions, 1 deletions
| diff --git a/test/Main.hs b/test/Main.hs index d617228..7cb15d5 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,6 +24,7 @@ import Hedgehog.Main  import Array  import AST  import AST.Pretty +import CHAD (defaultConfig)  import CHAD.Top  import CHAD.Types  import qualified Example @@ -41,7 +42,7 @@ data SimplIters = SimplIters Int | SimplFix  -- In addition to the gradient, also returns the pretty-printed differentiated term.  gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))  gradientByCHAD = \simplIters env term input -> -  let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term +  let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term        dterm | Dict <- envKnown env              = case simplIters of                  SimplIters n -> simplifyN n dtermNonSimpl | 
