summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-04-21 20:19:12 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-04-21 20:19:12 +0200
commit5adf77f813a5e79546674b82a855ea1d542931fc (patch)
treec7d49a4dba7054f9a1e74a350e5d8a80ccf28dc7 /test/Main.hs
parenta624136738fb1ad3bf801723b9afbf1132fad7f0 (diff)
test: Move generator helpers to top-level
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs92
1 files changed, 46 insertions, 46 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 3a6bc71..04246ce 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -33,6 +33,7 @@ import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
+import Example.Types
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
@@ -41,9 +42,6 @@ import Language
import Simplify
-type R = TScal TF64
-
-
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -351,6 +349,45 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
+gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
+gen_gmm = do
+ -- The input ranges here are completely arbitrary.
+ let tR = STScal STF64
+ kN <- Gen.integral (Range.linear 1 8)
+ kD <- Gen.integral (Range.linear 1 8)
+ kK <- Gen.integral (Range.linear 1 8)
+ let i2i64 = fromIntegral @Int @Int64
+ valpha <- genArray tR (ShNil `ShCons` kK)
+ vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
+ vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vm <- Gen.integral (Range.linear 0 5)
+ let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil)
+
+gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)])
+gen_neural = do
+ let tR = STScal STF64
+ let genLayer nin nout =
+ liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ <*> genArray tR (ShNil `ShCons` nout)
+ nin <- Gen.integral (Range.linear 1 10)
+ n1 <- Gen.integral (Range.linear 1 10)
+ n2 <- Gen.integral (Range.linear 1 10)
+ input <- genArray tR (ShNil `ShCons` nin)
+ lay1 <- genLayer nin n1
+ lay2 <- genLayer n1 n2
+ lay3 <- genArray tR (ShNil `ShCons` n2)
+ return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+
term_build1_sum :: Ex '[TArr N1 R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
@@ -480,9 +517,9 @@ tests_AD = testGroup "AD"
#L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
42
- ,adTestGen "neural" Example.neural genNeural
+ ,adTestGen "neural" Example.neural gen_neural
- ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural
+ ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural
,adTestTp "logsumexp" (C "" 1) $
fromNamed $ lambda @(TArr N1 _) #vec $ body $
@@ -491,51 +528,14 @@ tests_AD = testGroup "AD"
,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec
- ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM
+ ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
- ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
+ ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm
- ,adTestGen "gmm" (Example.gmmObjective False) genGMM
+ ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
- ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
+ ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm
]
- where
- genGMM = do
- -- The input ranges here are completely arbitrary.
- let tR = STScal STF64
- kN <- Gen.integral (Range.linear 1 8)
- kD <- Gen.integral (Range.linear 1 8)
- kK <- Gen.integral (Range.linear 1 8)
- let i2i64 = fromIntegral @Int @Int64
- valpha <- genArray tR (ShNil `ShCons` kK)
- vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
- vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
- vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
- vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
- vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- vm <- Gen.integral (Range.linear 0 5)
- let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
- k2 = 0.5 * vgamma * vgamma
- k3 = 0.42 -- don't feel like multigammaing today
- return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
- Value vm `SCons` vX `SCons`
- vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
- Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
- SNil)
-
- genNeural = do
- let tR = STScal STF64
- let genLayer nin nout =
- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
- <*> genArray tR (ShNil `ShCons` nout)
- nin <- Gen.integral (Range.linear 1 10)
- n1 <- Gen.integral (Range.linear 1 10)
- n2 <- Gen.integral (Range.linear 1 10)
- input <- genArray tR (ShNil `ShCons` nin)
- lay1 <- genLayer nin n1
- lay2 <- genLayer n1 n2
- lay3 <- genArray tR (ShNil `ShCons` n2)
- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
main :: IO ()
main = defaultMain $ testGroup "All"