diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
commit | 013e01e28aba090c065ed584671a65aa339ea51b (patch) | |
tree | 1595a8363fc181a13d41224e206d051d4e6a906b /src | |
parent | 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff) |
Test GMM; it fails
Diffstat (limited to 'src')
-rw-r--r-- | src/AST/Pretty.hs | 6 | ||||
-rw-r--r-- | src/CHAD.hs | 101 | ||||
-rw-r--r-- | src/Example/GMM.hs | 117 |
3 files changed, 199 insertions, 25 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 51d89dc..ec8574f 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -197,9 +197,9 @@ ppExpr' d val = \case e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 return $ showParen (d > 10) $ showString "custom " - . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " - . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " - . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " + . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") " + . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") " + . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") " . e1' . showString " " . e2' diff --git a/src/CHAD.hs b/src/CHAD.hs index 59d61a7..8b9f17a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -28,7 +28,6 @@ module CHAD ( Select, ) where -import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) import GHC.Stack (HasCallStack) @@ -198,6 +197,7 @@ type Storage :: Symbol -> Type data Storage s where SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions deriving instance Show (Storage s) -- | Environment description @@ -224,11 +224,13 @@ subDescr (des `DPush` (t, sto)) (SEYes sub) k = case sto of SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) subDescr (des `DPush` (_, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -240,8 +242,13 @@ select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) ---------------------------------- DERIVATIVES --------------------------------- @@ -338,13 +345,27 @@ conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t - -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - (Idx (Select env sto "merge") t) -conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) -conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} @@ -536,6 +557,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = (#acc :++: (#d :++: #shb :++: #tl))) SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + k (storepl `DPush` (t, SDiscr)) + envpro + (SENo prosub) + wf + -- Arrays with "merge" storage are promoted to an accumulator in envPro STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> k (storepl `DPush` (t, SAccum)) @@ -555,18 +583,40 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + -- "merge" values must be an array or fully discrete, so reject everything else. (TODO: generalise this) _ -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - -- where - -- containsTArr :: STy t' -> Bool - -- containsTArr = \case - -- STNil -> False - -- STPair a b -> containsTArr a || containsTArr b - -- STEither a b -> containsTArr a || containsTArr b - -- STArr{} -> True - -- STScal{} -> False - -- STAccum{} -> error "An accumulator in merge storage?" + error $ "Closure variable of 'build'-like thing contains a non-array non-discrete SMerge value: " ++ show t + + -- Discrete values are left as-is, nothing to do + SDiscr -> + k (storepl `DPush` (t, SDiscr)) + envpro + prosub + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + -- containsTArr :: STy t' -> Bool + -- containsTArr = \case + -- STNil -> False + -- STPair a b -> containsTArr a || containsTArr b + -- STEither a b -> containsTArr a || containsTArr b + -- STArr{} -> True + -- STScal{} -> False + -- STAccum{} -> error "An accumulator in merge storage?" makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e @@ -682,20 +732,27 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left accI -> + Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) - Right tupI -> + Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (select SMerge des)) + (ENil ext) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -929,7 +986,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs new file mode 100644 index 0000000..ff37f9a --- /dev/null +++ b/src/Example/GMM.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE TypeApplications #-} +module Example.GMM where + +import Language + + +type R = TScal TF64 +type I64 = TScal TI64 +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) + +-- N, D, K: integers > 0 +-- alpha, M, Q, L: the active parameters +-- X: inactive data +-- m: integer +-- k1: 1/2 N D log(2 pi) +-- k2: 1/2 gamma^2 +-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) +-- where n' = D + m + 1 +-- +-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. +-- +-- See: +-- - "A benchmark of selected algorithmic differentiation tools on some problems +-- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): +-- 889-906 (2018). +-- <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651> +-- <https://github.com/microsoft/ADBench> +-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. +-- Master thesis at Utrecht University. (Appendix B.1) +-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y> +-- <https://tomsmeding.com/f/master.pdf> +gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective = fromNamed $ + lambda #N $ lambda #D $ lambda #K $ + lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ + lambda #X $ lambda #m $ + lambda #k1 $ lambda #k2 $ lambda #k3 $ + body $ + let -- We have: + -- sum (exp (x - max(x))) + -- = sum (exp x / exp (max(x))) + -- = sum (exp x) / exp (max(x)) + -- Hence: + -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) + -- + -- So: + -- d/dxi log (sum (exp x)) + -- = 1/(sum (exp x)) * d/dxi sum (exp x) + -- = 1/(sum (exp x)) * sum (d/dxi exp x) + -- = 1/(sum (exp x)) * exp xi + -- = exp xi / sum (exp x) + -- (by (*)) + -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) + -- = exp (xi - max(x)) / sum (exp (x - max(x))) + logsumexp' = lambda @(TVec R) #vec $ body $ + let_ #m (maximum1i #vec) $ + log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m + -- custom (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) + -- (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ + -- let_ #s (idx0 (sum1i #ex)) $ + -- pair (log #s + #m) + -- (pair #ex #s)) + -- (#tape :-> #d :-> + -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) + -- nil #vec + logsumexp v = inline logsumexp' (SNil .$ v) + + mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ + let_ #hei (snd_ (fst_ (shape #mat))) $ + let_ #wid (snd_ (shape #mat)) $ + build1 #hei $ #i :-> + idx0 (sum1i (build1 #wid $ #j :-> + #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) + m *@ v = inline mulmatvec (SNil .$ m .$ v) + + subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ + build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i + a .- b = inline subvec (SNil .$ a .$ b) + + matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ + build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) + m .! i = inline matrow (SNil .$ m .$ i) + + normsq' = lambda @(TVec R) #vec $ body $ + idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) + normsq v = inline normsq' (SNil .$ v) + + qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ + let_ #n (snd_ (shape #q)) $ + build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> + let_ #i (snd_ (fst_ #idx)) $ + let_ #j (snd_ #idx) $ + if_ (#i .== #j) + (exp (#q ! pair nil #i)) + (if_ (#i .> #j) + (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + 0.0) + qmat q l = inline qmat' (SNil .$ q .$ l) + in let_ #k2arr (unit #k2) $ + - #k1 + + idx0 (sum1i (build1 #N $ #i :-> + logsumexp (build1 #K $ #k :-> + #alpha ! pair nil #k + + idx0 (sum1i (#Q .! #k)) + - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) + - toFloat_ #N * logsumexp #alpha + + idx0 (sum1i (build1 #K $ #k :-> + idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) + - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) + - #k3 |