summaryrefslogtreecommitdiff
path: root/bench
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-10 10:04:27 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-10 10:04:27 +0100
commit42d59947356ab51e5a4070b930f02f4909208d35 (patch)
tree3c8afab888e61c4e3157a257f0a40ae2fd4eb9c1 /bench
parent33e0ed21603cbd85d6aba6548811db27480647db (diff)
Complete GMM implementation
Diffstat (limited to 'bench')
-rw-r--r--bench/Bench/GMM.hs14
1 files changed, 10 insertions, 4 deletions
diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs
index ebbbe1e..9b84d23 100644
--- a/bench/Bench/GMM.hs
+++ b/bench/Bench/GMM.hs
@@ -3,8 +3,6 @@
{-# LANGUAGE TypeApplications #-}
module Bench.GMM where
-import AST
-import Data
import Language
@@ -31,7 +29,7 @@ type TMat = TArr (S (S Z))
-- <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
-- <https://github.com/microsoft/ADBench>
-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
--- Master thesis at Utrecht University.
+-- Master thesis at Utrecht University. (Appendix B.1)
-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
-- <https://tomsmeding.com/f/master.pdf>
objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
@@ -93,7 +91,15 @@ objective = fromNamed $
normsq v = inline normsq' (SNil .$ v)
qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
- _
+ let_ #n (snd_ (shape #q)) $
+ build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :->
+ let_ #i (snd_ (fst_ #idx)) $
+ let_ #j (snd_ #idx) $
+ if_ (#i .== #j)
+ (exp (#q ! pair nil #i))
+ (if_ (#i .> #j)
+ (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j)
+ 0.0)
qmat q l = inline qmat' (SNil .$ q .$ l)
in - #k1
+ idx0 (sum1i (build1 #N $ #i :->