summaryrefslogtreecommitdiff
path: root/bench
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
commitbb84f6930702a02ba982795e2bb95a64d61f672b (patch)
tree910b2a119f9758115d1b59e45d558fb983a9286b /bench
parent02db8c1929a25dda64e6cee7b7343833ee698f34 (diff)
Benchmark GMM
Diffstat (limited to 'bench')
-rw-r--r--bench/Main.hs29
1 files changed, 29 insertions, 0 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index c62b0f2..5bb81ac 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -13,6 +13,7 @@ module Main where
import Control.DeepSeq
import Criterion.Main
import Data.Coerce
+import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
@@ -22,6 +23,8 @@ import CHAD.Top
import CHAD.Types
import Data
import Example
+import Example.GMM
+import Example.Types
import Interpreter
import Interpreter.Rep
import Simplify
@@ -82,8 +85,34 @@ makeNeuralInputs =
lay3 = Value (genArray (ShNil `ShCons` n2))
in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil
+makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
+makeGMMInputs =
+ let genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))
+ -- these values are completely arbitrary lol
+ kN = 10
+ kD = 10
+ kK = 10
+ i2i64 = fromIntegral @Int @Int64
+ valpha = genArray (ShNil `ShCons` kK)
+ vM = genArray (ShNil `ShCons` kK `ShCons` kD)
+ vQ = genArray (ShNil `ShCons` kK `ShCons` kD)
+ vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX = genArray (ShNil `ShCons` kN `ShCons` kD)
+ vgamma = 0.42
+ vm = 2
+ k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ in Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil
+
main :: IO ()
main = defaultMain
[env (return makeNeuralInputs) $ \inputs ->
bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
+ ,env (return makeGMMInputs) $ \inputs ->
+ bench "gmm" (nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0))
]