From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/Example/GMM.hs | 123 --------------------------------------------------- src/Example/Types.hs | 11 ----- 2 files changed, 134 deletions(-) delete mode 100644 src/Example/GMM.hs delete mode 100644 src/Example/Types.hs (limited to 'src/Example') diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs deleted file mode 100644 index 206e534..0000000 --- a/src/Example/GMM.hs +++ /dev/null @@ -1,123 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE TypeApplications #-} -module Example.GMM where - -import Example.Types -import Language - - - --- N, D, K: integers > 0 --- alpha, M, Q, L: the active parameters --- X: inactive data --- m: integer --- k1: 1/2 N D log(2 pi) --- k2: 1/2 gamma^2 --- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) --- where n' = D + m + 1 --- --- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. --- --- See: --- - "A benchmark of selected algorithmic differentiation tools on some problems --- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): --- 889-906 (2018). --- --- --- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. --- Master thesis at Utrecht University. (Appendix B.1) --- --- --- --- The 'wrong' argument, when set to True, changes the objective function to --- one with a bug that makes a certain `build` result unused. This --- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's --- dense, even though it may be a zero (i.e. empty). The "unused" test in --- test/Main.hs tries to isolate this case, but the wrong version of --- gmmObjective is here to check (after that bug is fixed) whether it really --- fixes the original bug. -gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective wrong = fromNamed $ - lambda #N $ lambda #D $ lambda #K $ - lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ - lambda #X $ lambda #m $ - lambda #k1 $ lambda #k2 $ lambda #k3 $ - body $ - let -- We have: - -- sum (exp (x - max(x))) - -- = sum (exp x / exp (max(x))) - -- = sum (exp x) / exp (max(x)) - -- Hence: - -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) - -- - -- So: - -- d/dxi log (sum (exp x)) - -- = 1/(sum (exp x)) * d/dxi sum (exp x) - -- = 1/(sum (exp x)) * sum (d/dxi exp x) - -- = 1/(sum (exp x)) * exp xi - -- = exp xi / sum (exp x) - -- (by (*)) - -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) - -- = exp (xi - max(x)) / sum (exp (x - max(x))) - logsumexp' = lambda @(TVec R) #vec $ body $ - let_ #m (maximum1i #vec) $ - log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - -- custom (#_ :-> #v :-> - -- let_ #m (idx0 (maximum1i #v)) $ - -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) - -- (#_ :-> #v :-> - -- let_ #m (idx0 (maximum1i #v)) $ - -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ - -- let_ #s (idx0 (sum1i #ex)) $ - -- pair (log #s + #m) - -- (pair #ex #s)) - -- (#tape :-> #d :-> - -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) - -- nil #vec - logsumexp v = inline logsumexp' (SNil .$ v) - - mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ - let_ #hei (snd_ (fst_ (shape #mat))) $ - let_ #wid (snd_ (shape #mat)) $ - build1 #hei $ #i :-> - idx0 (sum1i (build1 #wid $ #j :-> - #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) - m *@ v = inline mulmatvec (SNil .$ m .$ v) - - subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ - build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i - a .- b = inline subvec (SNil .$ a .$ b) - - matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ - build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) - m .! i = inline matrow (SNil .$ m .$ i) - - normsq' = lambda @(TVec R) #vec $ body $ - idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) - normsq v = inline normsq' (SNil .$ v) - - qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ - let_ #n (snd_ (shape #q)) $ - build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> - let_ #i (snd_ (fst_ #idx)) $ - let_ #j (snd_ #idx) $ - if_ (#i .== #j) - (exp (#q ! pair nil #i)) - (if_ (#i .> #j) - (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) - else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j)) - 0.0) - qmat q l = inline qmat' (SNil .$ q .$ l) - in let_ #k2arr (unit #k2) $ - - #k1 - + idx0 (sum1i (build1 #N $ #i :-> - logsumexp (build1 #K $ #k :-> - #alpha ! pair nil #k - + idx0 (sum1i (#Q .! #k)) - - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) - - toFloat_ #N * logsumexp #alpha - + idx0 (sum1i (build1 #K $ #k :-> - idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) - - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) - - #k3 diff --git a/src/Example/Types.hs b/src/Example/Types.hs deleted file mode 100644 index d63159b..0000000 --- a/src/Example/Types.hs +++ /dev/null @@ -1,11 +0,0 @@ -{-# LANGUAGE DataKinds #-} -module Example.Types where - -import AST -import Data - - -type R = TScal TF64 -type I64 = TScal TI64 -type TVec = TArr (S Z) -type TMat = TArr (S (S Z)) -- cgit v1.2.3-70-g09d2