summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-10 12:39:08 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-10 12:39:08 +0100
commit013e01e28aba090c065ed584671a65aa339ea51b (patch)
tree1595a8363fc181a13d41224e206d051d4e6a906b /test
parent9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff)
Test GMM; it fails
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs41
1 files changed, 33 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 2573a32..75ab11a 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -10,6 +10,7 @@
module Main where
import Data.Bifunctor
+import Data.Int (Int64)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
@@ -22,6 +23,7 @@ import AST.Pretty
import CHAD.Top
import CHAD.Types
import qualified Example
+import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
@@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do
scCHAD = envScalars env gradCHAD
scCHAD_S = envScalars env gradCHAD_S
annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
- annotate (ppExpr knownEnv expr)
- annotate ppdterm
- annotate ppdterm_S
- diff ppdterm_S20 (==) ppdterm_S
- diff outChad closeIsh outChad_S
- diff outPrimal closeIsh outChad_S
- diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
- diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S
+ -- annotate (ppExpr knownEnv expr)
+ -- annotate ppdterm
+ -- annotate ppdterm_S
+ diff ppdterm_S (==) ppdterm_S20
+ diff outChad_S closeIsh outChad
+ diff outChad_S closeIsh outPrimal
+ diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD
+ diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
@@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD"
lay2 <- genLayer n1 n2
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+
+ ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do
+ -- The input ranges here are completely arbitrary.
+ let tR = STScal STF64
+ kN <- Gen.integral (Range.linear 1 10)
+ kD <- Gen.integral (Range.linear 1 10)
+ kK <- Gen.integral (Range.linear 1 10)
+ let i2i64 = fromIntegral @Int @Int64
+ valpha <- genArray tR (ShNil `ShCons` kK)
+ vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
+ vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vm <- Gen.integral (Range.linear 0 5)
+ let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil))
]
main :: IO ()