summaryrefslogtreecommitdiff
path: root/src/Example/GMM.hs
blob: 12bbd98b56bb2ab7fe4f6dab380f489a21c009ba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeApplications #-}
module Example.GMM where

import Example.Types
import Language



-- N, D, K: integers > 0
-- alpha, M, Q, L: the active parameters
-- X: inactive data
-- m: integer
-- k1: 1/2 N D log(2 pi)
-- k2: 1/2 gamma^2
-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D))
--       where n' = D + m + 1
--
-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m.
--
-- Layout of the per-component parameters (as consumed by qmat' below): row k
-- of Q holds the logarithms of the diagonal of the k-th lower-triangular
-- matrix (its diagonal entries are exp(Q[k][i])), and row k of L holds that
-- matrix's strictly-lower entries packed row by row, so entry (i, j) with
-- i > j is L[k][i*(i-1)/2 + j]. Entries above the diagonal are 0.
--
-- The objective computed is
--     - k1
--     + sum_i logsumexp_k (alpha[k] + sum(Q[k]) - 1/2 |Qmat_k (X[i] - M[k])|^2)
--     - N * logsumexp(alpha)
--     + sum_k (k2 * (|exp(Q[k])|^2 + |L[k]|^2) - m * sum(Q[k]))
--     - k3
-- where Qmat_k is the lower-triangular matrix built from Q[k] and L[k].
--
-- See:
-- - "A benchmark of selected algorithmic differentiation tools on some problems
--   in computer vision and machine learning". Optim. Methods Softw. 33(4-6):
--   889-906 (2018).
--   <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
--   <https://github.com/microsoft/ADBench>
-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
--   Master thesis at Utrecht University. (Appendix B.1)
--   <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
--   <https://tomsmeding.com/f/master.pdf>
--
-- The 'wrong' argument, when set to True, changes the objective function to
-- one with a bug that makes a certain `build` result unused (the row of L
-- passed to qmat' is ignored). This makes the CHAD code fail because it tries
-- to use a D2 (TArr) as if it's dense, even though it may be a zero (i.e.
-- empty). The "unused" test in test/Main.hs tries to isolate this bug, but the
-- wrong version of gmmObjective is kept here to check (after that bug is
-- fixed) whether the fix also resolves the original failure.
gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
gmmObjective wrong = fromNamed $
  lambda #N $ lambda #D $ lambda #K $
  lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
  lambda #X $ lambda #m $
  lambda #k1 $ lambda #k2 $ lambda #k3 $
  body $
    let -- We have:
        -- sum (exp (x - max(x)))
        -- = sum (exp x / exp (max(x)))
        -- = sum (exp x) / exp (max(x))
        -- Hence:
        -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x))    (*)
        --
        -- So:
        -- d/dxi log (sum (exp x))
        -- = 1/(sum (exp x)) * d/dxi sum (exp x)
        -- = 1/(sum (exp x)) * sum (d/dxi exp x)
        -- = 1/(sum (exp x)) * exp xi
        -- = exp xi / sum (exp x)
        -- (by (*))
        -- = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
        -- = exp (xi - max(x)) / sum (exp (x - max(x)))
        logsumexp' = lambda @(TVec R) #vec $ body $
                       let_ #m (maximum1i #vec) $
                         log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
                       -- custom (#_ :-> #v :->
                       --           let_ #m (idx0 (maximum1i #v)) $
                       --             log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
                       --        (#_ :-> #v :->
                       --           let_ #m (idx0 (maximum1i #v)) $
                       --           let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
                       --           let_ #s (idx0 (sum1i #ex)) $
                       --             pair (log #s + #m)
                       --                  (pair #ex #s))
                       --        (#tape :-> #d :->
                       --           map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
                       --        nil #vec
        logsumexp v = inline logsumexp' (SNil .$ v)

        -- Matrix-vector product: result[i] = sum_j mat[i][j] * vec[j].
        mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
                      let_ #hei (snd_ (fst_ (shape #mat))) $
                      let_ #wid (snd_ (shape #mat)) $
                        build1 #hei $ #i :->
                          idx0 (sum1i (build1 #wid $ #j :->
                                         #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
        m *@ v = inline mulmatvec (SNil .$ m .$ v)

        -- Elementwise difference, sized by the first argument.
        subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $
                    build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i
        a .- b = inline subvec (SNil .$ a .$ b)

        -- Row i of a matrix, as a vector.
        matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $
                   build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j)
        m .! i = inline matrow (SNil .$ m .$ i)

        -- Squared Euclidean norm: sum_i vec[i]^2.
        normsq' = lambda @(TVec R) #vec $ body $
                    idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x)))
        normsq v = inline normsq' (SNil .$ v)

        -- Builds the n-by-n lower-triangular matrix with diagonal exp(q[i]) and
        -- strictly-lower entries taken from l, packed row by row.
        qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
                  let_ #n (snd_ (shape #q)) $
                    build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :->
                      let_ #i (snd_ (fst_ #idx)) $
                      let_ #j (snd_ #idx) $
                        if_ (#i .== #j)
                          (exp (#q ! pair nil #i))
                          (if_ (#i .> #j)
                            -- Row-major packed index of (i, j): i*(i-1)/2 + j.
                            -- The product is parenthesised explicitly so the
                            -- index does not depend on the fixity declared for
                            -- `idiv` (with the default infixl 9 it would parse
                            -- as i * ((i-1) `idiv` 2) + j). i*(i-1) is always
                            -- even, so the division is exact.
                            -- In the 'wrong' variant #l is deliberately unused.
                            (if wrong then toFloat_ ((#i * (#i - 1)) `idiv` 2 + #j)
                                      else #l ! pair nil ((#i * (#i - 1)) `idiv` 2 + #j))
                            0.0)
        qmat q l = inline qmat' (SNil .$ q .$ l)
    in let_ #k2arr (unit #k2) $
       - #k1
       + idx0 (sum1i (build1 #N $ #i :->
             logsumexp (build1 #K $ #k :->
                 #alpha ! pair nil #k
                 + idx0 (sum1i (#Q .! #k))
                 - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k))))))
       - toFloat_ #N * logsumexp #alpha
       + idx0 (sum1i (build1 #K $ #k :->
                        idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
                        - toFloat_ #m * idx0 (sum1i (#Q .! #k))))
       - #k3