diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 23:21:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 23:21:53 +0100 |
commit | 8ac606b8f0c482679e9f017e6b2f0f33d58f9573 (patch) | |
tree | ffc02d6e14a9a6fba02d51cf78202839c3db3e83 | |
parent | c2cc922c7b56d17080aef1c6b41e2e98120dd7af (diff) |
Add some simplify flags infrastructure for debugging
-rw-r--r-- | src/Simplify.hs | 30 | ||||
-rw-r--r-- | test/Main.hs | 6 |
2 files changed, 29 insertions, 7 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index fb858e7..0132f85 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -7,7 +7,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify (simplifyN, simplifyFix) where +module Simplify ( + simplifyN, simplifyFix, + SimplifyConfig(..), simplifyWith, simplifyFixWith, +) where import Data.Function (fix) import Data.Monoid (Any(..)) @@ -17,21 +20,40 @@ import AST.Count import Data +-- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some. +data SimplifyConfig = SimplifyConfig + +defaultSimplifyConfig :: SimplifyConfig +defaultSimplifyConfig = SimplifyConfig + simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t simplifyN 0 = id simplifyN n = simplifyN (n - 1) . simplify simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = let ?accumInScope = checkAccumInScope @env knownEnv in snd . simplify' +simplify = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = defaultSimplifyConfig + in snd . simplify' + +simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyWith config = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config + in snd . simplify' simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplifyFix = +simplifyFix = simplifyFixWith defaultSimplifyConfig + +simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyFixWith config = let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config in fix $ \loop e -> let (Any act, e') = simplify' e in if act then loop e' else e' -simplify' :: (?accumInScope :: Bool) => Ex env t -> (Any, Ex env t) +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) simplify' = \case -- inlining ELet _ rhs body diff --git a/test/Main.hs b/test/Main.hs index b3a0795..9ff82a1 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -44,8 +44,8 @@ import Simplify data SimplIters = SimplIters Int | SimplFix deriving (Show) -simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t -simplifyWith iters env | Dict <- envKnown env = +simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t +simplifyIters iters env | Dict <- envKnown env = case iters of SimplIters n -> simplifyN n SimplFix -> simplifyFix @@ -53,7 +53,7 @@ simplifyWith iters env | Dict <- envKnown env = -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD simplIters env term input = - let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term + let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term (out, grad) = interpretOpen False input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) |