diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 | 
| commit | 013e01e28aba090c065ed584671a65aa339ea51b (patch) | |
| tree | 1595a8363fc181a13d41224e206d051d4e6a906b /test | |
| parent | 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff) | |
Test GMM; it fails
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 41 | 
1 files changed, 33 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs index 2573a32..75ab11a 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,6 +10,7 @@  module Main where  import Data.Bifunctor +import Data.Int (Int64)  import Data.List (intercalate)  import Hedgehog  import qualified Hedgehog.Gen as Gen @@ -22,6 +23,7 @@ import AST.Pretty  import CHAD.Top  import CHAD.Types  import qualified Example +import qualified Example.GMM as Example  import ForwardAD  import Interpreter  import Interpreter.Rep @@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do        scCHAD = envScalars env gradCHAD        scCHAD_S = envScalars env gradCHAD_S    annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) -  annotate (ppExpr knownEnv expr) -  annotate ppdterm -  annotate ppdterm_S -  diff ppdterm_S20 (==) ppdterm_S -  diff outChad closeIsh outChad_S -  diff outPrimal closeIsh outChad_S -  diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S -  diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S +  -- annotate (ppExpr knownEnv expr) +  -- annotate ppdterm +  -- annotate ppdterm_S +  diff ppdterm_S (==) ppdterm_S20 +  diff outChad_S closeIsh outChad +  diff outChad_S closeIsh outPrimal +  diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD +  diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd    where      envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]      envScalars SNil SNil = [] @@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD"        lay2 <- genLayer n1 n2        lay3 <- genArray tR (ShNil `ShCons` n2)        return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + +  ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do +      -- The input ranges here are completely arbitrary. +      let tR = STScal STF64 +      kN <- Gen.integral (Range.linear 1 10) +      kD <- Gen.integral (Range.linear 1 10) +      kK <- Gen.integral (Range.linear 1 10) +      let i2i64 = fromIntegral @Int @Int64 +      valpha <- genArray tR (ShNil `ShCons` kK) +      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) +      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) +      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) +      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) +      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) +      vm <- Gen.integral (Range.linear 0 5) +      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) +          k2 = 0.5 * vgamma * vgamma +          k3 = 0.42  -- don't feel like multigammaing today +      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` +              Value vm `SCons` vX `SCons` +              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` +              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` +              SNil))    ]  main :: IO ()  | 
