summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-04 23:21:53 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-04 23:21:53 +0100
commit8ac606b8f0c482679e9f017e6b2f0f33d58f9573 (patch)
treeffc02d6e14a9a6fba02d51cf78202839c3db3e83
parentc2cc922c7b56d17080aef1c6b41e2e98120dd7af (diff)
Add some simplify flags infrastructure for debugging
-rw-r--r--src/Simplify.hs30
-rw-r--r--test/Main.hs6
2 files changed, 29 insertions, 7 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index fb858e7..0132f85 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -7,7 +7,10 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-module Simplify (simplifyN, simplifyFix) where
+module Simplify (
+ simplifyN, simplifyFix,
+ SimplifyConfig(..), simplifyWith, simplifyFixWith,
+) where
import Data.Function (fix)
import Data.Monoid (Any(..))
@@ -17,21 +20,40 @@ import AST.Count
import Data
+-- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
+data SimplifyConfig = SimplifyConfig
+
+defaultSimplifyConfig :: SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig
+
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
simplifyN n = simplifyN (n - 1) . simplify
simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplify = let ?accumInScope = checkAccumInScope @env knownEnv in snd . simplify'
+simplify =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = defaultSimplifyConfig
+ in snd . simplify'
+
+simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
+simplifyWith config =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = config
+ in snd . simplify'
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplifyFix =
+simplifyFix = simplifyFixWith defaultSimplifyConfig
+
+simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
+simplifyFixWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
+ ?config = config
in fix $ \loop e ->
let (Any act, e') = simplify' e
in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool) => Ex env t -> (Any, Ex env t)
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
simplify' = \case
-- inlining
ELet _ rhs body
diff --git a/test/Main.hs b/test/Main.hs
index b3a0795..9ff82a1 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -44,8 +44,8 @@ import Simplify
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
-simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t
-simplifyWith iters env | Dict <- envKnown env =
+simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
+simplifyIters iters env | Dict <- envKnown env =
case iters of
SimplIters n -> simplifyN n
SimplFix -> simplifyFix
@@ -53,7 +53,7 @@ simplifyWith iters env | Dict <- envKnown env =
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
- let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
+ let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
(out, grad) = interpretOpen False input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))