diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-21 20:19:12 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-21 20:19:12 +0200 |
commit | 5adf77f813a5e79546674b82a855ea1d542931fc (patch) | |
tree | c7d49a4dba7054f9a1e74a350e5d8a80ccf28dc7 /test/Main.hs | |
parent | a624136738fb1ad3bf801723b9afbf1132fad7f0 (diff) |
test: Move generator helpers to top-level
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 92 |
1 files changed, 46 insertions, 46 deletions
diff --git a/test/Main.hs b/test/Main.hs index 3a6bc71..04246ce 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -33,6 +33,7 @@ import CHAD.Types.ToTan import Compile import qualified Example import qualified Example.GMM as Example +import Example.Types import ForwardAD import ForwardAD.DualNumbers import Interpreter @@ -41,9 +42,6 @@ import Language import Simplify -type R = TScal TF64 - - data SimplIters = SimplIters Int | SimplFix deriving (Show) @@ -351,6 +349,45 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) +gen_gmm = do + -- The input ranges here are completely arbitrary. + let tR = STScal STF64 + kN <- Gen.integral (Range.linear 1 8) + kD <- Gen.integral (Range.linear 1 8) + kK <- Gen.integral (Range.linear 1 8) + let i2i64 = fromIntegral @Int @Int64 + valpha <- genArray tR (ShNil `ShCons` kK) + vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) + vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vm <- Gen.integral (Range.linear 0 5) + let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil) + +gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]) +gen_neural = do + let tR = STScal STF64 + let genLayer nin nout = + liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + <*> genArray tR (ShNil `ShCons` nout) + nin <- Gen.integral (Range.linear 1 10) + n1 <- Gen.integral (Range.linear 1 10) + n2 <- Gen.integral (Range.linear 1 10) + input <- genArray tR (ShNil `ShCons` nin) + lay1 <- genLayer nin n1 + lay2 <- genLayer n1 n2 + lay3 <- genArray tR (ShNil `ShCons` n2) + return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + term_build1_sum :: Ex '[TArr N1 R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ @@ -480,9 +517,9 @@ tests_AD = testGroup "AD" #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0)))) 42 - ,adTestGen "neural" Example.neural genNeural + ,adTestGen "neural" Example.neural gen_neural - ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural + ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural ,adTestTp "logsumexp" (C "" 1) $ fromNamed $ lambda @(TArr N1 _) #vec $ body $ @@ -491,51 +528,14 @@ tests_AD = testGroup "AD" ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec - ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM + ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm - ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM + ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm - ,adTestGen "gmm" (Example.gmmObjective False) genGMM + ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm - ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM + ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm ] - where - genGMM = do - -- The input ranges here are completely arbitrary. - let tR = STScal STF64 - kN <- Gen.integral (Range.linear 1 8) - kD <- Gen.integral (Range.linear 1 8) - kK <- Gen.integral (Range.linear 1 8) - let i2i64 = fromIntegral @Int @Int64 - valpha <- genArray tR (ShNil `ShCons` kK) - vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) - vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) - vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) - vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - vm <- Gen.integral (Range.linear 0 5) - let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) - k2 = 0.5 * vgamma * vgamma - k3 = 0.42 -- don't feel like multigammaing today - return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` - Value vm `SCons` vX `SCons` - vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` - Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` - SNil) - - genNeural = do - let tR = STScal STF64 - let genLayer nin nout = - liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) - <*> genArray tR (ShNil `ShCons` nout) - nin <- Gen.integral (Range.linear 1 10) - n1 <- Gen.integral (Range.linear 1 10) - n2 <- Gen.integral (Range.linear 1 10) - input <- genArray tR (ShNil `ShCons` nin) - lay1 <- genLayer nin n1 - lay2 <- genLayer n1 n2 - lay3 <- genArray tR (ShNil `ShCons` n2) - return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) main :: IO () main = defaultMain $ testGroup "All" |