author     Tom Smeding <tom@tomsmeding.com>    2024-11-09 22:59:30 +0100
committer  Tom Smeding <tom@tomsmeding.com>    2024-11-09 22:59:30 +0100
commit     33e0ed21603cbd85d6aba6548811db27480647db (patch)
tree       10cedc6821e2ef78a5d35013cf3dcd7855ab04e1
parent     d4d4473ee229674f73929c0860a7e29302330361 (diff)
Most of GMM
-rw-r--r--  bench/Bench/GMM.hs  108
-rw-r--r--  chad-fast.cabal      11
2 files changed, 116 insertions, 3 deletions
diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs
new file mode 100644
index 0000000..ebbbe1e
--- /dev/null
+++ b/bench/Bench/GMM.hs
@@ -0,0 +1,108 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE TypeApplications #-}
+module Bench.GMM where
+
+import AST
+import Data
+import Language
+
+
+type R = TScal TF64
+type I64 = TScal TI64
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))
+
+-- N, D, K: integers > 0
+-- alpha, M, Q, L: the active parameters
+-- X: inactive data
+-- m: integer
+-- k1: 1/2 N D log(2 pi)
+-- k2: 1/2 gamma^2
+-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D))
+--   where n' = D + m + 1
+--
+-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m.
+--
+-- See:
+-- - "A benchmark of selected algorithmic differentiation tools on some problems
+--   in computer vision and machine learning". Optim. Methods Softw. 33(4-6):
+--   889-906 (2018).
+--   <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
+--   <https://github.com/microsoft/ADBench>
+-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
+--   Master thesis at Utrecht University.
+--   <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
+--   <https://tomsmeding.com/f/master.pdf>
+objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+objective = fromNamed $
+  lambda #N $ lambda #D $ lambda #K $
+  lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
+  lambda #X $ lambda #m $
+  lambda #k1 $ lambda #k2 $ lambda #k3 $
+  body $
+    let -- We have:
+        --   sum (exp (x - max(x)))
+        --   = sum (exp x / exp (max(x)))
+        --   = sum (exp x) / exp (max(x))
+        -- Hence:
+        --   sum (exp x) = sum (exp (x - max(x))) * exp (max(x))   (*)
+        --
+        -- So:
+        --   d/dxi log (sum (exp x))
+        --   = 1/(sum (exp x)) * d/dxi sum (exp x)
+        --   = 1/(sum (exp x)) * sum (d/dxi exp x)
+        --   = 1/(sum (exp x)) * exp xi
+        --   = exp xi / sum (exp x)
+        --       (by (*))
+        --   = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
+        --   = exp (xi - max(x)) / sum (exp (x - max(x)))
+        logsumexp' = lambda @(TVec R) #vec $ body $
+          custom (#_ :-> #v :->
+                    let_ #m (idx0 (maximum1i #v)) $
+                      log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
+                 (#_ :-> #v :->
+                    let_ #m (idx0 (maximum1i #v)) $
+                    let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
+                    let_ #s (idx0 (sum1i #ex)) $
+                      pair (log #s + #m)
+                           (pair #ex #s))
+                 (#tape :-> #d :->
+                    map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
+                 nil #vec
+        logsumexp v = inline logsumexp' (SNil .$ v)
+
+        mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
+          let_ #hei (snd_ (fst_ (shape #mat))) $
+          let_ #wid (snd_ (shape #mat)) $
+            build1 #hei $ #i :->
+              idx0 (sum1i (build1 #wid $ #j :->
+                             #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+        m *@ v = inline mulmatvec (SNil .$ m .$ v)
+
+        subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $
+          build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i
+        a .- b = inline subvec (SNil .$ a .$ b)
+
+        matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $
+          build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j)
+        m .! i = inline matrow (SNil .$ m .$ i)
+
+        normsq' = lambda @(TVec R) #vec $ body $
+          idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x)))
+        normsq v = inline normsq' (SNil .$ v)
+
+        qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
+          _
+        qmat q l = inline qmat' (SNil .$ q .$ l)
+    in - #k1
+       + idx0 (sum1i (build1 #N $ #i :->
+           logsumexp (build1 #K $ #k :->
+             #alpha ! pair nil #k
+             + idx0 (sum1i (#Q .! #k))
+             - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k))))))
+       - toFloat_ #N * logsumexp #alpha
+       + idx0 (sum1i (build1 #K $ #k :->
+           #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
+           - toFloat_ #m * idx0 (sum1i (#Q .! #k))))
+       - #k3
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 94d7423..cdfc1b1 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -52,29 +52,34 @@ library
 
 test-suite example
   type: exitcode-stdio-1.0
-  main-is: example/Main.hs
+  main-is: Main.hs
   build-depends: base, chad-fast
+  hs-source-dirs: example
   default-language: Haskell2010
   ghc-options: -Wall -threaded
 
 test-suite test
   type: exitcode-stdio-1.0
-  main-is: test/Main.hs
+  main-is: Main.hs
   build-depends:
     chad-fast,
    base,
    dependent-map,
    hedgehog,
+  hs-source-dirs: test
  default-language: Haskell2010
  ghc-options: -Wall -threaded

benchmark bench
  type: exitcode-stdio-1.0
-  main-is: bench/Main.hs
+  main-is: Main.hs
+  other-modules:
+    Bench.GMM
  build-depends:
    chad-fast,
    base,
    criterion,
    deepseq,
+  hs-source-dirs: bench
  default-language: Haskell2010
  ghc-options: -Wall -threaded
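
The comment block above logsumexp' derives the max-shifted form of log-sum-exp and its derivative. The following standalone Haskell sketch shows the same trick on plain lists; it is an illustration only, and the names logSumExp and gradLogSumExp do not appear in the commit.

  module Main where

  -- Numerically stable log (sum (exp xs)): shift by the maximum, per the
  -- identity (*) derived in the source comment.
  logSumExp :: [Double] -> Double
  logSumExp xs = m + log (sum [exp (x - m) | x <- xs])
    where m = maximum xs

  -- Gradient of logSumExp, following the same derivation:
  -- d/dx_i log (sum (exp x)) = exp (x_i - max x) / sum (exp (x - max x)).
  gradLogSumExp :: [Double] -> [Double]
  gradLogSumExp xs = [e / s | e <- ex]
    where
      m  = maximum xs
      ex = [exp (x - m) | x <- xs]
      s  = sum ex

  main :: IO ()
  main = do
    print (logSumExp [1000, 1000.5])  -- ~1000.974; the unshifted form overflows
    print (gradLogSumExp [0, 1, 2])   -- the softmax weights; they sum to 1

Shifting by the maximum makes the largest exponent exp 0 = 1, so the sum cannot overflow, and the gradient comes out as exactly the softmax of the input; this is what the custom primal/reverse pair above implements, with the shifted exponentials and their sum stored on the tape.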
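
qmat' is still a typed hole (_) in this commit. In the ADBench GMM formulation that the objective follows, the per-component matrix Q_k is lower triangular, with exp(q) on the diagonal (which is why idx0 (sum1i (#Q .! #k)) serves as its log-determinant) and the entries of l filling the strict lower triangle. The sketch below is a hypothetical plain-Haskell rendering under that assumption; the name qMatrix and the row-by-row packing of l are guesses, not taken from this commit.

  -- Hypothetical sketch of what qmat presumably computes; the packing order
  -- of l is an assumption.
  qMatrix :: [Double] -> [Double] -> [[Double]]
  qMatrix q l = [[entry i j | j <- [0 .. d - 1]] | i <- [0 .. d - 1]]
    where
      d = length q
      entry i j
        | i == j    = exp (q !! i)                  -- diagonal: exp of q
        | j < i     = l !! (i * (i - 1) `div` 2 + j)  -- strict lower triangle
        | otherwise = 0                             -- upper triangle is zero

  main :: IO ()
  main = mapM_ print (qMatrix [0.0, 0.1, 0.2] [1, 2, 3])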