diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 10:04:27 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 10:04:27 +0100 | 
| commit | 42d59947356ab51e5a4070b930f02f4909208d35 (patch) | |
| tree | 3c8afab888e61c4e3157a257f0a40ae2fd4eb9c1 /bench | |
| parent | 33e0ed21603cbd85d6aba6548811db27480647db (diff) | |
Complete GMM implementation
Diffstat (limited to 'bench')
| -rw-r--r-- | bench/Bench/GMM.hs | 14 | 
1 files changed, 10 insertions, 4 deletions
| diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs index ebbbe1e..9b84d23 100644 --- a/bench/Bench/GMM.hs +++ b/bench/Bench/GMM.hs @@ -3,8 +3,6 @@  {-# LANGUAGE TypeApplications #-}  module Bench.GMM where -import AST -import Data  import Language @@ -31,7 +29,7 @@ type TMat = TArr (S (S Z))  --   <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>  --   <https://github.com/microsoft/ADBench>  -- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. ---   Master thesis at Utrecht University. +--   Master thesis at Utrecht University. (Appendix B.1)  --   <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>  --   <https://tomsmeding.com/f/master.pdf>  objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R @@ -93,7 +91,15 @@ objective = fromNamed $          normsq v = inline normsq' (SNil .$ v)          qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ -                  _ +                  let_ #n (snd_ (shape #q)) $ +                    build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> +                      let_ #i (snd_ (fst_ #idx)) $ +                      let_ #j (snd_ #idx) $ +                        if_ (#i .== #j) +                          (exp (#q ! pair nil #i)) +                          (if_ (#i .> #j) +                            (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) +                            0.0)          qmat q l = inline qmat' (SNil .$ q .$ l)      in - #k1         + idx0 (sum1i (build1 #N $ #i :-> | 
