diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
commit | 013e01e28aba090c065ed584671a65aa339ea51b (patch) | |
tree | 1595a8363fc181a13d41224e206d051d4e6a906b /test | |
parent | 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff) |
Test GMM; it fails
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs index 2573a32..75ab11a 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,6 +10,7 @@ module Main where import Data.Bifunctor +import Data.Int (Int64) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -22,6 +23,7 @@ import AST.Pretty import CHAD.Top import CHAD.Types import qualified Example +import qualified Example.GMM as Example import ForwardAD import Interpreter import Interpreter.Rep @@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) - annotate (ppExpr knownEnv expr) - annotate ppdterm - annotate ppdterm_S - diff ppdterm_S20 (==) ppdterm_S - diff outChad closeIsh outChad_S - diff outPrimal closeIsh outChad_S - diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S - diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S + -- annotate (ppExpr knownEnv expr) + -- annotate ppdterm + -- annotate ppdterm_S + diff ppdterm_S (==) ppdterm_S20 + diff outChad_S closeIsh outChad + diff outChad_S closeIsh outPrimal + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] @@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD" lay2 <- genLayer n1 n2 lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + + ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do + -- The input ranges here are completely arbitrary. + let tR = STScal STF64 + kN <- Gen.integral (Range.linear 1 10) + kD <- Gen.integral (Range.linear 1 10) + kK <- Gen.integral (Range.linear 1 10) + let i2i64 = fromIntegral @Int @Int64 + valpha <- genArray tR (ShNil `ShCons` kK) + vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) + vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vm <- Gen.integral (Range.linear 0 5) + let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil)) ] main :: IO () |