diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 23:49:34 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 23:49:34 +0200 | 
| commit | 6a0381f9c6cfc56ac805801bf4cefda8305ff055 (patch) | |
| tree | 2dc1e7b77f2df5d65db852db3b9c53dfc4d76f7a /test/Main.hs | |
| parent | 3847c6ae2d5eb581dac88629e7534aa42e143411 (diff) | |
Make test suite a little friendlier to debugging
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 12 | 
1 files changed, 10 insertions, 2 deletions
| diff --git a/test/Main.hs b/test/Main.hs index a3fa484..e325b64 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -21,8 +21,11 @@ import qualified Hedgehog.Gen as Gen  import qualified Hedgehog.Range as Range  import Hedgehog.Main +import Debug.Trace +  import Array  import AST +import AST.Pretty  import CHAD  import CHAD.Types  import Data @@ -53,7 +56,8 @@ gradientByCHAD = \env term input ->            dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)            input1 = toPrimalE env input            (_out, grad) = interpretOpen input1 dterm -      in unTup vUnpair (d2e env) (Value grad) +      in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $ +         unTup vUnpair (d2e env) (Value grad)    where      makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')      makeMergeDescr SNil = DTop @@ -73,6 +77,10 @@ gradientByCHAD = \env term input ->        STScal _ -> id        STAccum{} -> error "Accumulators not allowed in input program" +    primalEnv :: SList STy env' -> SList STy (D1E env') +    primalEnv SNil = SNil +    primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env +  gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)  gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)    where @@ -226,7 +234,7 @@ adTestGen expr envGenerator = property $ do      envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs  tests :: IO Bool -tests = checkParallel $ Group "AD" +tests = checkSequential $ Group "AD"    [("id", adTest $ fromNamed $ lambda #x $ body $ #x)    ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x) | 
