diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 23:49:34 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 23:49:34 +0200 |
commit | 6a0381f9c6cfc56ac805801bf4cefda8305ff055 (patch) | |
tree | 2dc1e7b77f2df5d65db852db3b9c53dfc4d76f7a /test | |
parent | 3847c6ae2d5eb581dac88629e7534aa42e143411 (diff) |
Make test suite a little friendlier to debugging
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/test/Main.hs b/test/Main.hs index a3fa484..e325b64 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -21,8 +21,11 @@ import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Hedgehog.Main +import Debug.Trace + import Array import AST +import AST.Pretty import CHAD import CHAD.Types import Data @@ -53,7 +56,8 @@ gradientByCHAD = \env term input -> dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) input1 = toPrimalE env input (_out, grad) = interpretOpen input1 dterm - in unTup vUnpair (d2e env) (Value grad) + in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $ + unTup vUnpair (d2e env) (Value grad) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop @@ -73,6 +77,10 @@ gradientByCHAD = \env term input -> STScal _ -> id STAccum{} -> error "Accumulators not allowed in input program" + primalEnv :: SList STy env' -> SList STy (D1E env') + primalEnv SNil = SNil + primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env + gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input) where @@ -226,7 +234,7 @@ adTestGen expr envGenerator = property $ do envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs tests :: IO Bool -tests = checkParallel $ Group "AD" +tests = checkSequential $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x) |