From 8ac606b8f0c482679e9f017e6b2f0f33d58f9573 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 4 Mar 2025 23:21:53 +0100 Subject: Add some simplify flags infrastructure for debugging --- test/Main.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index b3a0795..9ff82a1 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -44,8 +44,8 @@ import Simplify data SimplIters = SimplIters Int | SimplFix deriving (Show) -simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t -simplifyWith iters env | Dict <- envKnown env = +simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t +simplifyIters iters env | Dict <- envKnown env = case iters of SimplIters n -> simplifyN n SimplFix -> simplifyFix @@ -53,7 +53,7 @@ simplifyWith iters env | Dict <- envKnown env = -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD simplIters env term input = - let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term + let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term (out, grad) = interpretOpen False input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- cgit v1.2.3-70-g09d2