summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-20 00:02:11 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-20 00:02:11 +0200
commit59ea6579c0cceeecaef7c27e39aab39828a4fbeb (patch)
tree449236067e6a3b2d894623fc90506fdde46db301 /test/Main.hs
parenta4b3eb76acbec30ffeae119a4dc6e4c9f64396fe (diff)
WIP parallel test suitetest-parallel
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs31
1 files changed, 19 insertions, 12 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 0a57cbf..4cdab1c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -20,6 +20,9 @@ import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Internal.Gen as IGen
+import qualified Hedgehog.Internal.Tree as ITree
+import qualified Hedgehog.Internal.Seed as ISeed
import qualified Hedgehog.Range as Range
import Test.Framework
@@ -40,6 +43,7 @@ import Interpreter
import Interpreter.Rep
import Language
import Simplify
+import Data.Maybe (fromJust)
data TypedValue t = TypedValue (STy t) (Rep t)
@@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env =
SimplIters n -> simplifyN n
SimplFix -> simplifyFix
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
-gradientByCHAD simplIters env term input =
- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False env input dterm
- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
+-- gradientByCHAD simplIters env term input =
+-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
+-- (out, grad) = interpretOpen False env input dterm
+-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
-gradientByCHAD' simplIters env term input =
- second (second (toTanE env input)) $
- gradientByCHAD simplIters env term input
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
+-- gradientByCHAD' simplIters env term input =
+-- second (second (toTanE env input)) $
+-- gradientByCHAD simplIters env term input
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
@@ -302,7 +306,7 @@ adTestGen name expr envGenerator =
exprS = simplifyFix expr
in withCompiled env expr $ \primalfun ->
withCompiled env (simplifyFix expr) $ \primalSfun ->
- testGroupCollapse name
+ groupSetCollapse $ groupSetSequential $ testGroup name
[adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
,adTestGenFwd env envGenerator exprS
,testGroup "chad"
@@ -661,6 +665,9 @@ tests_AD = testGroup "AD"
,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
]
+gmminp :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
+gmminp = ITree.treeValue $ fromJust $ IGen.evalGen 30 (ISeed.from 3) gen_gmm
+
main :: IO ()
main = defaultMain $ testGroup "All"
[tests_Compile