diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:06 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:06 +0100 |
commit | bb84f6930702a02ba982795e2bb95a64d61f672b (patch) | |
tree | 910b2a119f9758115d1b59e45d558fb983a9286b | |
parent | 02db8c1929a25dda64e6cee7b7343833ee698f34 (diff) |
Benchmark GMM
-rw-r--r-- | bench/Main.hs | 29 | ||||
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/Example.hs | 6 | ||||
-rw-r--r-- | src/Example/GMM.hs | 5 | ||||
-rw-r--r-- | src/Example/Types.hs | 11 |
5 files changed, 43 insertions, 9 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index c62b0f2..5bb81ac 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -13,6 +13,7 @@ module Main where import Control.DeepSeq import Criterion.Main import Data.Coerce +import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) @@ -22,6 +23,8 @@ import CHAD.Top import CHAD.Types import Data import Example +import Example.GMM +import Example.Types import Interpreter import Interpreter.Rep import Simplify @@ -82,8 +85,34 @@ makeNeuralInputs = lay3 = Value (genArray (ShNil `ShCons` n2)) in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil +makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] +makeGMMInputs = + let genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1])) + -- these values are completely arbitrary lol + kN = 10 + kD = 10 + kK = 10 + i2i64 = fromIntegral @Int @Int64 + valpha = genArray (ShNil `ShCons` kK) + vM = genArray (ShNil `ShCons` kK `ShCons` kD) + vQ = genArray (ShNil `ShCons` kK `ShCons` kD) + vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX = genArray (ShNil `ShCons` kN `ShCons` kD) + vgamma = 0.42 + vm = 2 + k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + in Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil + main :: IO () main = defaultMain [env (return makeNeuralInputs) $ \inputs -> bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0)) + ,env (return makeGMMInputs) $ \inputs -> + bench "gmm" (nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0)) ] diff --git a/chad-fast.cabal b/chad-fast.cabal index 8817718..5306595 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -27,6 +27,7 @@ library Example Example.Format Example.GMM + Example.Types ForwardAD ForwardAD.DualNumbers ForwardAD.DualNumbers.Types diff --git a/src/Example.hs b/src/Example.hs index a08724b..390031e 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -19,6 +19,7 @@ import Simplify import Debug.Trace import Example.Format +import Example.Types -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) @@ -110,8 +111,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ #b ! pair nil 3 -type R = TScal TF64 - senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] senv7 = knownEnv @@ -150,9 +149,6 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ let_ #inp #input $ layer (STPair (STPair (STPair STNil tpair) tpair) tpair) -type TVec = TArr (S Z) -type TMat = TArr (S (S Z)) - neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index 1db88bd..12bbd98 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -3,13 +3,10 @@ {-# LANGUAGE TypeApplications #-} module Example.GMM where +import Example.Types import Language -type R = TScal TF64 -type I64 = TScal TI64 -type TVec = TArr (S Z) -type TMat = TArr (S (S Z)) -- N, D, K: integers > 0 -- alpha, M, Q, L: the active parameters diff --git a/src/Example/Types.hs b/src/Example/Types.hs new file mode 100644 index 0000000..d63159b --- /dev/null +++ b/src/Example/Types.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE DataKinds #-} +module Example.Types where + +import AST +import Data + + +type R = TScal TF64 +type I64 = TScal TI64 +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) |