summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-22 23:49:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-22 23:49:34 +0200
commit6a0381f9c6cfc56ac805801bf4cefda8305ff055 (patch)
tree2dc1e7b77f2df5d65db852db3b9c53dfc4d76f7a /test
parent3847c6ae2d5eb581dac88629e7534aa42e143411 (diff)
Make test suite a little friendlier to debugging
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs12
1 files changed, 10 insertions, 2 deletions
diff --git a/test/Main.hs b/test/Main.hs
index a3fa484..e325b64 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -21,8 +21,11 @@ import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
+import Debug.Trace
+
import Array
import AST
+import AST.Pretty
import CHAD
import CHAD.Types
import Data
@@ -53,7 +56,8 @@ gradientByCHAD = \env term input ->
dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
input1 = toPrimalE env input
(_out, grad) = interpretOpen input1 dterm
- in unTup vUnpair (d2e env) (Value grad)
+ in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $
+ unTup vUnpair (d2e env) (Value grad)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
@@ -73,6 +77,10 @@ gradientByCHAD = \env term input ->
STScal _ -> id
STAccum{} -> error "Accumulators not allowed in input program"
+ primalEnv :: SList STy env' -> SList STy (D1E env')
+ primalEnv SNil = SNil
+ primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
+
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
where
@@ -226,7 +234,7 @@ adTestGen expr envGenerator = property $ do
envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
tests :: IO Bool
-tests = checkParallel $ Group "AD"
+tests = checkSequential $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)