summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
commitbb84f6930702a02ba982795e2bb95a64d61f672b (patch)
tree910b2a119f9758115d1b59e45d558fb983a9286b
parent02db8c1929a25dda64e6cee7b7343833ee698f34 (diff)
Benchmark GMM
-rw-r--r--bench/Main.hs29
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/Example.hs6
-rw-r--r--src/Example/GMM.hs5
-rw-r--r--src/Example/Types.hs11
5 files changed, 43 insertions, 9 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index c62b0f2..5bb81ac 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -13,6 +13,7 @@ module Main where
import Control.DeepSeq
import Criterion.Main
import Data.Coerce
+import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
@@ -22,6 +23,8 @@ import CHAD.Top
import CHAD.Types
import Data
import Example
+import Example.GMM
+import Example.Types
import Interpreter
import Interpreter.Rep
import Simplify
@@ -82,8 +85,34 @@ makeNeuralInputs =
lay3 = Value (genArray (ShNil `ShCons` n2))
in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil
+makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
+makeGMMInputs =
+ let genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))
+ -- these values are completely arbitrary lol
+ kN = 10
+ kD = 10
+ kK = 10
+ i2i64 = fromIntegral @Int @Int64
+ valpha = genArray (ShNil `ShCons` kK)
+ vM = genArray (ShNil `ShCons` kK `ShCons` kD)
+ vQ = genArray (ShNil `ShCons` kK `ShCons` kD)
+ vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX = genArray (ShNil `ShCons` kN `ShCons` kD)
+ vgamma = 0.42
+ vm = 2
+ k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ in Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil
+
main :: IO ()
main = defaultMain
[env (return makeNeuralInputs) $ \inputs ->
bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
+ ,env (return makeGMMInputs) $ \inputs ->
+ bench "gmm" (nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0))
]
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 8817718..5306595 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -27,6 +27,7 @@ library
Example
Example.Format
Example.GMM
+ Example.Types
ForwardAD
ForwardAD.DualNumbers
ForwardAD.DualNumbers.Types
diff --git a/src/Example.hs b/src/Example.hs
index a08724b..390031e 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -19,6 +19,7 @@ import Simplify
import Debug.Trace
import Example.Format
+import Example.Types
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
@@ -110,8 +111,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
#b ! pair nil 3
-type R = TScal TF64
-
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv
@@ -150,9 +149,6 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
let_ #inp #input $
layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
-type TVec = TArr (S Z)
-type TMat = TArr (S (S Z))
-
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index 1db88bd..12bbd98 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -3,13 +3,10 @@
{-# LANGUAGE TypeApplications #-}
module Example.GMM where
+import Example.Types
import Language
-type R = TScal TF64
-type I64 = TScal TI64
-type TVec = TArr (S Z)
-type TMat = TArr (S (S Z))
-- N, D, K: integers > 0
-- alpha, M, Q, L: the active parameters
diff --git a/src/Example/Types.hs b/src/Example/Types.hs
new file mode 100644
index 0000000..d63159b
--- /dev/null
+++ b/src/Example/Types.hs
@@ -0,0 +1,11 @@
+{-# LANGUAGE DataKinds #-}
+module Example.Types where
+
+import AST
+import Data
+
+
+type R = TScal TF64
+type I64 = TScal TI64
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))