diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
commit | 013e01e28aba090c065ed584671a65aa339ea51b (patch) | |
tree | 1595a8363fc181a13d41224e206d051d4e6a906b | |
parent | 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff) |
Test GMM; it fails
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 6 | ||||
-rw-r--r-- | src/CHAD.hs | 101 | ||||
-rw-r--r-- | src/Example/GMM.hs (renamed from bench/Bench/GMM.hs) | 37 | ||||
-rw-r--r-- | test/Main.hs | 41 |
5 files changed, 136 insertions, 51 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index cdfc1b1..274e497 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -26,6 +26,7 @@ library Data Example Example.Format + Example.GMM ForwardAD ForwardAD.DualNumbers ForwardAD.DualNumbers.Types @@ -74,7 +75,6 @@ benchmark bench type: exitcode-stdio-1.0 main-is: Main.hs other-modules: - Bench.GMM build-depends: chad-fast, base, diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 51d89dc..ec8574f 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -197,9 +197,9 @@ ppExpr' d val = \case e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 return $ showParen (d > 10) $ showString "custom " - . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " - . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " - . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " + . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") " + . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") " + . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") " . e1' . showString " " . e2' diff --git a/src/CHAD.hs b/src/CHAD.hs index 59d61a7..8b9f17a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -28,7 +28,6 @@ module CHAD ( Select, ) where -import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) import GHC.Stack (HasCallStack) @@ -198,6 +197,7 @@ type Storage :: Symbol -> Type data Storage s where SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions deriving instance Show (Storage s) -- | Environment description @@ -224,11 +224,13 @@ subDescr (des `DPush` (t, sto)) (SEYes sub) k = case sto of SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) subDescr (des `DPush` (_, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -240,8 +242,13 @@ select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) ---------------------------------- DERIVATIVES --------------------------------- @@ -338,13 +345,27 @@ conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t - -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - (Idx (Select env sto "merge") t) -conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) -conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} @@ -536,6 +557,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = (#acc :++: (#d :++: #shb :++: #tl))) SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + k (storepl `DPush` (t, SDiscr)) + envpro + (SENo prosub) + wf + -- Arrays with "merge" storage are promoted to an accumulator in envPro STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> k (storepl `DPush` (t, SAccum)) @@ -555,18 +583,40 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + -- "merge" values must be an array or fully discrete, so reject everything else. (TODO: generalise this) _ -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - -- where - -- containsTArr :: STy t' -> Bool - -- containsTArr = \case - -- STNil -> False - -- STPair a b -> containsTArr a || containsTArr b - -- STEither a b -> containsTArr a || containsTArr b - -- STArr{} -> True - -- STScal{} -> False - -- STAccum{} -> error "An accumulator in merge storage?" + error $ "Closure variable of 'build'-like thing contains a non-array non-discrete SMerge value: " ++ show t + + -- Discrete values are left as-is, nothing to do + SDiscr -> + k (storepl `DPush` (t, SDiscr)) + envpro + prosub + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + -- containsTArr :: STy t' -> Bool + -- containsTArr = \case + -- STNil -> False + -- STPair a b -> containsTArr a || containsTArr b + -- STEither a b -> containsTArr a || containsTArr b + -- STArr{} -> True + -- STScal{} -> False + -- STAccum{} -> error "An accumulator in merge storage?" makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e @@ -682,20 +732,27 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left accI -> + Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) - Right tupI -> + Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (select SMerge des)) + (ENil ext) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -929,7 +986,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in diff --git a/bench/Bench/GMM.hs b/src/Example/GMM.hs index 9b84d23..ff37f9a 100644 --- a/bench/Bench/GMM.hs +++ b/src/Example/GMM.hs @@ -1,7 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeApplications #-} -module Bench.GMM where +module Example.GMM where import Language @@ -32,8 +32,8 @@ type TMat = TArr (S (S Z)) -- Master thesis at Utrecht University. (Appendix B.1) -- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y> -- <https://tomsmeding.com/f/master.pdf> -objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -objective = fromNamed $ +gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective = fromNamed $ lambda #N $ lambda #D $ lambda #K $ lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ lambda #X $ lambda #m $ @@ -56,18 +56,20 @@ objective = fromNamed $ -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) -- = exp (xi - max(x)) / sum (exp (x - max(x))) logsumexp' = lambda @(TVec R) #vec $ body $ - custom (#_ :-> #v :-> - let_ #m (idx0 (maximum1i #v)) $ - log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) - (#_ :-> #v :-> - let_ #m (idx0 (maximum1i #v)) $ - let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ - let_ #s (idx0 (sum1i #ex)) $ - pair (log #s + #m) - (pair #ex #s)) - (#tape :-> #d :-> - map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) - nil #vec + let_ #m (maximum1i #vec) $ + log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m + -- custom (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) + -- (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ + -- let_ #s (idx0 (sum1i #ex)) $ + -- pair (log #s + #m) + -- (pair #ex #s)) + -- (#tape :-> #d :-> + -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) + -- nil #vec logsumexp v = inline logsumexp' (SNil .$ v) mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ @@ -101,7 +103,8 @@ objective = fromNamed $ (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) 0.0) qmat q l = inline qmat' (SNil .$ q .$ l) - in - #k1 + in let_ #k2arr (unit #k2) $ + - #k1 + idx0 (sum1i (build1 #N $ #i :-> logsumexp (build1 #K $ #k :-> #alpha ! pair nil #k @@ -109,6 +112,6 @@ objective = fromNamed $ - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) - toFloat_ #N * logsumexp #alpha + idx0 (sum1i (build1 #K $ #k :-> - #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) + idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) - #k3 diff --git a/test/Main.hs b/test/Main.hs index 2573a32..75ab11a 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,6 +10,7 @@ module Main where import Data.Bifunctor +import Data.Int (Int64) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -22,6 +23,7 @@ import AST.Pretty import CHAD.Top import CHAD.Types import qualified Example +import qualified Example.GMM as Example import ForwardAD import Interpreter import Interpreter.Rep @@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) - annotate (ppExpr knownEnv expr) - annotate ppdterm - annotate ppdterm_S - diff ppdterm_S20 (==) ppdterm_S - diff outChad closeIsh outChad_S - diff outPrimal closeIsh outChad_S - diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S - diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S + -- annotate (ppExpr knownEnv expr) + -- annotate ppdterm + -- annotate ppdterm_S + diff ppdterm_S (==) ppdterm_S20 + diff outChad_S closeIsh outChad + diff outChad_S closeIsh outPrimal + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] @@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD" lay2 <- genLayer n1 n2 lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + + ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do + -- The input ranges here are completely arbitrary. + let tR = STScal STF64 + kN <- Gen.integral (Range.linear 1 10) + kD <- Gen.integral (Range.linear 1 10) + kK <- Gen.integral (Range.linear 1 10) + let i2i64 = fromIntegral @Int @Int64 + valpha <- genArray tR (ShNil `ShCons` kK) + vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) + vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vm <- Gen.integral (Range.linear 0 5) + let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil)) ] main :: IO () |