{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeApplications #-}
module Bench.GMM where

import AST
import Data
import Language

type R = TScal TF64
type I64 = TScal TI64
type TVec = TArr (S Z)
type TMat = TArr (S (S Z))

-- N, D, K: integers > 0
-- alpha, M, Q, L: the active parameters
-- X: inactive data
-- m: integer
-- k1: 1/2 N D log(2 pi)
-- k2: 1/2 gamma^2
-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D))
--     where n' = D + m + 1
--
-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m.
--
-- See:
-- - "A benchmark of selected algorithmic differentiation tools on some problems
--   in computer vision and machine learning". Optim. Methods Softw. 33(4-6):
--   889-906 (2018).
--
-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
--   Master thesis at Utrecht University.
objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
objective = fromNamed $
  lambda #N $ lambda #D $ lambda #K $
  lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
  lambda #X $ lambda #m $
  lambda #k1 $ lambda #k2 $ lambda #k3 $
  body $
    let -- Numerically stable log-sum-exp, written as a 'custom' primitive with
        -- a hand-written reverse derivative. The derivation of that derivative:
        --
        -- We have:
        --   sum (exp (x - max(x)))
        --   = sum (exp x / exp (max(x)))
        --   = sum (exp x) / exp (max(x))
        -- Hence:
        --   sum (exp x) = sum (exp (x - max(x))) * exp (max(x))        (*)
        --
        -- So:
        --   d/dxi log (sum (exp x))
        --   = 1/(sum (exp x)) * d/dxi sum (exp x)
        --   = 1/(sum (exp x)) * sum (d/dxi exp x)
        --   = 1/(sum (exp x)) * exp xi
        --   = exp xi / sum (exp x)
        --       (by (*))
        --   = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
        --   = exp (xi - max(x)) / sum (exp (x - max(x)))
        logsumexp' = lambda @(TVec R) #vec $ body $
          custom
            -- primal computation (no tape)
            (#_ :-> #v :->
               let_ #m (idx0 (maximum1i #v)) $
                 log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
            -- primal computation, additionally storing (exp (x - max(x)), their sum) as tape
            (#_ :-> #v :->
               let_ #m (idx0 (maximum1i #v)) $
               let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
               let_ #s (idx0 (sum1i #ex)) $
                 pair (log #s + #m) (pair #ex #s))
            -- reverse derivative, using the formula derived above
            (#tape :-> #d :->
               map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
            nil
            #vec
        logsumexp v = inline logsumexp' (SNil .$ v)

        -- Matrix-vector product.
        mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
          let_ #hei (snd_ (fst_ (shape #mat))) $
          let_ #wid (snd_ (shape #mat)) $
            build1 #hei $ #i :->
              idx0 (sum1i (build1 #wid $ #j :->
                             #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
        m *@ v = inline mulmatvec (SNil .$ m .$ v)

        -- Element-wise vector subtraction.
        subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $
          build1 (snd_ (shape #a)) $ #i :->
            #a ! pair nil #i - #b ! pair nil #i
        a .- b = inline subvec (SNil .$ a .$ b)

        -- Extract row i of a matrix as a vector.
        matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $
          build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j)
        m .! i = inline matrow (SNil .$ m .$ i)

        -- Squared Euclidean norm of a vector.
        normsq' = lambda @(TVec R) #vec $ body $
          idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x)))
        normsq v = inline normsq' (SNil .$ v)

        -- TODO: body still missing. Presumably this should build the D-by-D
        -- lower-triangular Q matrix of a mixture component from its
        -- log-diagonal 'q' and packed strictly-lower-triangular entries 'l',
        -- as in the GMM objective of the ADBench paper cited above.
        qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
          _
        qmat q l = inline qmat' (SNil .$ q .$ l)

    in -- The GMM objective itself, assembled from the helpers above.
       - #k1
       + idx0 (sum1i (build1 #N $ #i :->
           logsumexp (build1 #K $ #k :->
             #alpha ! pair nil #k
               + idx0 (sum1i (#Q .! #k))
               - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k)
                                 *@ ((#X .! #i) .- (#M .! #k))))))
       - toFloat_ #N * logsumexp #alpha
       + idx0 (sum1i (build1 #K $ #k :->
           #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
             - toFloat_ #m * idx0 (sum1i (#Q .! #k))))
       - #k3
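
-- A minimal plain-Haskell reference sketch (not part of the benchmark program
-- above) of the numerically stable log-sum-exp and the gradient derived in the
-- comments of 'logsumexp'' above; it can be used to cross-check the 'custom'
-- primitive. The names 'logSumExpRef' and 'logSumExpGradRef' are illustrative
-- only and do not occur elsewhere in this codebase.
logSumExpRef :: [Double] -> Double
logSumExpRef xs = log (sum [exp (x - m) | x <- xs]) + m
  where m = maximum xs

-- Reverse-mode rule from the derivation above:
--   d/dxi log (sum (exp x)) = exp (xi - max(x)) / sum (exp (x - max(x))),
-- scaled by the incoming cotangent 'd' (the role of '#d' in the 'custom' call).
logSumExpGradRef :: Double -> [Double] -> [Double]
logSumExpGradRef d xs = [exp (x - m) / s * d | x <- xs]
  where m = maximum xs
        s = sum [exp (x - m) | x <- xs]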
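
-- A plain-Haskell sketch of what 'qmat'' presumably needs to compute, based on
-- the GMM formulation in the ADBench paper cited above: a lower-triangular
-- D-by-D matrix with exp q on the diagonal and the entries of l filling the
-- strictly lower triangle. The column-major packing order of 'l' assumed here
-- is a guess and should be checked against the benchmark's data files;
-- 'qMatRef' is an illustrative name only.
qMatRef :: [Double] -> [Double] -> [[Double]]
qMatRef q l =
  [ [ entry i j | j <- [0 .. d - 1] ] | i <- [0 .. d - 1] ]
  where
    d = length q
    entry i j
      | i == j    = exp (q !! i)
      | i > j     = l !! lowerIndex i j
      | otherwise = 0
    -- Index into 'l' for the strictly-lower-triangular entry (i, j) with
    -- i > j, assuming the strict lower triangle is packed column by column.
    lowerIndex i j = sum [ d - 1 - c | c <- [0 .. j - 1] ] + (i - j - 1)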