author    Tom Smeding <tom@tomsmeding.com>  2024-11-09 22:59:30 +0100
committer Tom Smeding <tom@tomsmeding.com>  2024-11-09 22:59:30 +0100
commit    33e0ed21603cbd85d6aba6548811db27480647db (patch)
tree      10cedc6821e2ef78a5d35013cf3dcd7855ab04e1
parent    d4d4473ee229674f73929c0860a7e29302330361 (diff)
Most of GMM
-rw-r--r--  bench/Bench/GMM.hs  108
-rw-r--r--  chad-fast.cabal      11
2 files changed, 116 insertions, 3 deletions
diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs
new file mode 100644
index 0000000..ebbbe1e
--- /dev/null
+++ b/bench/Bench/GMM.hs
@@ -0,0 +1,108 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE TypeApplications #-}
+module Bench.GMM where
+
+import AST
+import Data
+import Language
+
+
+type R = TScal TF64
+type I64 = TScal TI64
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))
+
+-- N, D, K: integers > 0
+-- alpha, M, Q, L: the active parameters
+-- X: inactive data
+-- m: integer
+-- k1: 1/2 N D log(2 pi)
+-- k2: 1/2 gamma^2
+-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D))
+-- where n' = D + m + 1
+--
+-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m.
+--
+-- See:
+-- - "A benchmark of selected algorithmic differentiation tools on some problems
+--   in computer vision and machine learning". Optim. Methods Softw. 33(4-6):
+--   889-906 (2018).
+--   <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
+--   <https://github.com/microsoft/ADBench>
+-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
+--   Master's thesis at Utrecht University.
+--   <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
+--   <https://tomsmeding.com/f/master.pdf>
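As a concrete reading of the k1/k2/k3 definitions above, they could be precomputed from the raw file inputs (gamma, m) roughly as follows. This is a plain-Haskell sketch, not part of this module: the helper names (logMultiGamma, constants) are illustrative only, and logGamma is assumed to come from the math-functions package.

import Numeric.SpecFunctions (logGamma)  -- package: math-functions

-- log of the multivariate gamma function Gamma_D(a):
-- log Gamma_D(a) = D(D-1)/4 * log pi + sum_{j=1..D} logGamma (a - (j-1)/2)
logMultiGamma :: Double -> Int -> Double
logMultiGamma a d =
  0.25 * fromIntegral (d * (d - 1)) * log pi
    + sum [logGamma (a - 0.5 * fromIntegral (j - 1)) | j <- [1 .. d]]

-- k1, k2, k3 as in the comment above, with n' = D + m + 1
constants :: Int -> Int -> Int -> Double -> Int -> (Double, Double, Double)
constants n d k gamma m =
  let n' = fromIntegral (d + m + 1)
      k1 = 0.5 * fromIntegral (n * d) * log (2 * pi)
      k2 = 0.5 * gamma * gamma
      k3 = fromIntegral k * (n' * fromIntegral d * (log gamma - 0.5 * log 2)
                             - logMultiGamma (0.5 * n') d)
  in (k1, k2, k3)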
+objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+objective = fromNamed $
+  lambda #N $ lambda #D $ lambda #K $
+  lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
+  lambda #X $ lambda #m $
+  lambda #k1 $ lambda #k2 $ lambda #k3 $
+  body $
+  let -- We have:
+      --   sum (exp (x - max(x)))
+      --     = sum (exp x / exp (max(x)))
+      --     = sum (exp x) / exp (max(x))
+      -- Hence:
+      --   sum (exp x) = sum (exp (x - max(x))) * exp (max(x))   (*)
+      --
+      -- So:
+      --   d/dxi log (sum (exp x))
+      --     = 1/(sum (exp x)) * d/dxi sum (exp x)
+      --     = 1/(sum (exp x)) * sum (d/dxi exp x)
+      --     = 1/(sum (exp x)) * exp xi
+      --     = exp xi / sum (exp x)
+      --       (by (*))
+      --     = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
+      --     = exp (xi - max(x)) / sum (exp (x - max(x)))
+      logsumexp' = lambda @(TVec R) #vec $ body $
+        custom (#_ :-> #v :->
+                  let_ #m (idx0 (maximum1i #v)) $
+                    log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
+               (#_ :-> #v :->
+                  let_ #m (idx0 (maximum1i #v)) $
+                  let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
+                  let_ #s (idx0 (sum1i #ex)) $
+                    pair (log #s + #m)
+                         (pair #ex #s))
+               (#tape :-> #d :->
+                  map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
+               nil #vec
+      logsumexp v = inline logsumexp' (SNil .$ v)
+
+      mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
+        let_ #hei (snd_ (fst_ (shape #mat))) $
+        let_ #wid (snd_ (shape #mat)) $
+          build1 #hei $ #i :->
+            idx0 (sum1i (build1 #wid $ #j :->
+              #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+      m *@ v = inline mulmatvec (SNil .$ m .$ v)
+
+      subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $
+        build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i
+      a .- b = inline subvec (SNil .$ a .$ b)
+
+      matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $
+        build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j)
+      m .! i = inline matrow (SNil .$ m .$ i)
+
+      normsq' = lambda @(TVec R) #vec $ body $
+        idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x)))
+      normsq v = inline normsq' (SNil .$ v)
+
+      qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
+        _
+      qmat q l = inline qmat' (SNil .$ q .$ l)
+  in - #k1
+     + idx0 (sum1i (build1 #N $ #i :->
+         logsumexp (build1 #K $ #k :->
+           #alpha ! pair nil #k
+           + idx0 (sum1i (#Q .! #k))
+           - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k))))))
+     - toFloat_ #N * logsumexp #alpha
+     + idx0 (sum1i (build1 #K $ #k :->
+         #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
+         - toFloat_ #m * idx0 (sum1i (#Q .! #k))))
+     - #k3
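The comment block inside logsumexp' derives both the max-shifted evaluation of log (sum (exp x)) and its gradient. The same two formulas, written out as a plain-Haskell sketch on lists (illustrative only; the benchmark itself uses the custom-derivative AST operation above):

-- Forward pass, shifted by m = max(x) so that exp cannot overflow:
-- log (sum (exp x)) = log (sum (exp (x - m))) + m
logSumExp :: [Double] -> Double
logSumExp xs =
  let m = maximum xs
  in log (sum [exp (x - m) | x <- xs]) + m

-- Reverse pass with incoming cotangent d, following the derivation:
-- d/dxi log (sum (exp x)) = exp (xi - m) / sum (exp (x - m))
gradLogSumExp :: Double -> [Double] -> [Double]
gradLogSumExp d xs =
  let m  = maximum xs
      ex = [exp (x - m) | x <- xs]  -- the shifted exponentials
      s  = sum ex                   -- and their sum
  in [e / s * d | e <- ex]

The pair (ex, s) is exactly what the custom operation stores on its tape, so the backward pass costs only one more map over the vector.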
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 94d7423..cdfc1b1 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -52,29 +52,34 @@ library

test-suite example
  type: exitcode-stdio-1.0
-  main-is: example/Main.hs
+  main-is: Main.hs
  build-depends: base, chad-fast
+  hs-source-dirs: example
  default-language: Haskell2010
  ghc-options: -Wall -threaded

test-suite test
  type: exitcode-stdio-1.0
-  main-is: test/Main.hs
+  main-is: Main.hs
  build-depends:
    chad-fast,
    base,
    dependent-map,
    hedgehog,
+  hs-source-dirs: test
  default-language: Haskell2010
  ghc-options: -Wall -threaded

benchmark bench
  type: exitcode-stdio-1.0
-  main-is: bench/Main.hs
+  main-is: Main.hs
+  other-modules:
+    Bench.GMM
  build-depends:
    chad-fast,
    base,
    criterion,
    deepseq,
+  hs-source-dirs: bench
  default-language: Haskell2010
  ghc-options: -Wall -threaded
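One piece of the objective is still missing in this commit: qmat' in Bench/GMM.hs is left as a typed hole (_). In the ADBench formulation of GMM, Q(q, l) is the D-by-D lower-triangular matrix with exp q on the diagonal and the entries of l filling the strict lower triangle. A hypothetical list-based sketch of that convention, not the implementation this repository eventually uses:

-- q: log-diagonal, length D; l: strict lower triangle, length D*(D-1)/2,
-- stored row by row.
qmatList :: [Double] -> [Double] -> [[Double]]
qmatList q l =
  let d = length q
      lowerIdx i j = i * (i - 1) `div` 2 + j  -- offset of entry (i, j), j < i
  in [ [ if j == i then exp (q !! i)
         else if j < i then l !! lowerIdx i j
         else 0
       | j <- [0 .. d - 1] ]
     | i <- [0 .. d - 1] ]

Under this convention, #Q .! #k in the objective is the k-th log-diagonal q, so idx0 (sum1i (#Q .! #k)) is the log-determinant of Q(q, l): the determinant of a triangular matrix is the product of its diagonal, here exp q.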