From 013e01e28aba090c065ed584671a65aa339ea51b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 10 Nov 2024 12:39:08 +0100 Subject: Test GMM; it fails --- bench/Bench/GMM.hs | 114 --------------------------------------------------- chad-fast.cabal | 2 +- src/AST/Pretty.hs | 6 +-- src/CHAD.hs | 101 +++++++++++++++++++++++++++++++++++---------- src/Example/GMM.hs | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++++ test/Main.hs | 41 +++++++++++++++---- 6 files changed, 233 insertions(+), 148 deletions(-) delete mode 100644 bench/Bench/GMM.hs create mode 100644 src/Example/GMM.hs diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs deleted file mode 100644 index 9b84d23..0000000 --- a/bench/Bench/GMM.hs +++ /dev/null @@ -1,114 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE TypeApplications #-} -module Bench.GMM where - -import Language - - -type R = TScal TF64 -type I64 = TScal TI64 -type TVec = TArr (S Z) -type TMat = TArr (S (S Z)) - --- N, D, K: integers > 0 --- alpha, M, Q, L: the active parameters --- X: inactive data --- m: integer --- k1: 1/2 N D log(2 pi) --- k2: 1/2 gamma^2 --- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) --- where n' = D + m + 1 --- --- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. --- --- See: --- - "A benchmark of selected algorithmic differentiation tools on some problems --- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): --- 889-906 (2018). --- --- --- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. --- Master thesis at Utrecht University. (Appendix B.1) --- --- -objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -objective = fromNamed $ - lambda #N $ lambda #D $ lambda #K $ - lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ - lambda #X $ lambda #m $ - lambda #k1 $ lambda #k2 $ lambda #k3 $ - body $ - let -- We have: - -- sum (exp (x - max(x))) - -- = sum (exp x / exp (max(x))) - -- = sum (exp x) / exp (max(x)) - -- Hence: - -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) - -- - -- So: - -- d/dxi log (sum (exp x)) - -- = 1/(sum (exp x)) * d/dxi sum (exp x) - -- = 1/(sum (exp x)) * sum (d/dxi exp x) - -- = 1/(sum (exp x)) * exp xi - -- = exp xi / sum (exp x) - -- (by (*)) - -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) - -- = exp (xi - max(x)) / sum (exp (x - max(x))) - logsumexp' = lambda @(TVec R) #vec $ body $ - custom (#_ :-> #v :-> - let_ #m (idx0 (maximum1i #v)) $ - log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) - (#_ :-> #v :-> - let_ #m (idx0 (maximum1i #v)) $ - let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ - let_ #s (idx0 (sum1i #ex)) $ - pair (log #s + #m) - (pair #ex #s)) - (#tape :-> #d :-> - map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) - nil #vec - logsumexp v = inline logsumexp' (SNil .$ v) - - mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ - let_ #hei (snd_ (fst_ (shape #mat))) $ - let_ #wid (snd_ (shape #mat)) $ - build1 #hei $ #i :-> - idx0 (sum1i (build1 #wid $ #j :-> - #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) - m *@ v = inline mulmatvec (SNil .$ m .$ v) - - subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ - build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i - a .- b = inline subvec (SNil .$ a .$ b) - - matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ - build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) - m .! i = inline matrow (SNil .$ m .$ i) - - normsq' = lambda @(TVec R) #vec $ body $ - idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) - normsq v = inline normsq' (SNil .$ v) - - qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ - let_ #n (snd_ (shape #q)) $ - build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> - let_ #i (snd_ (fst_ #idx)) $ - let_ #j (snd_ #idx) $ - if_ (#i .== #j) - (exp (#q ! pair nil #i)) - (if_ (#i .> #j) - (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) - 0.0) - qmat q l = inline qmat' (SNil .$ q .$ l) - in - #k1 - + idx0 (sum1i (build1 #N $ #i :-> - logsumexp (build1 #K $ #k :-> - #alpha ! pair nil #k - + idx0 (sum1i (#Q .! #k)) - - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) - - toFloat_ #N * logsumexp #alpha - + idx0 (sum1i (build1 #K $ #k :-> - #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) - - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) - - #k3 diff --git a/chad-fast.cabal b/chad-fast.cabal index cdfc1b1..274e497 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -26,6 +26,7 @@ library Data Example Example.Format + Example.GMM ForwardAD ForwardAD.DualNumbers ForwardAD.DualNumbers.Types @@ -74,7 +75,6 @@ benchmark bench type: exitcode-stdio-1.0 main-is: Main.hs other-modules: - Bench.GMM build-depends: chad-fast, base, diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 51d89dc..ec8574f 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -197,9 +197,9 @@ ppExpr' d val = \case e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 return $ showParen (d > 10) $ showString "custom " - . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " - . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " - . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " + . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") " + . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") " + . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") " . e1' . showString " " . e2' diff --git a/src/CHAD.hs b/src/CHAD.hs index 59d61a7..8b9f17a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -28,7 +28,6 @@ module CHAD ( Select, ) where -import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) import GHC.Stack (HasCallStack) @@ -198,6 +197,7 @@ type Storage :: Symbol -> Type data Storage s where SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions deriving instance Show (Storage s) -- | Environment description @@ -224,11 +224,13 @@ subDescr (des `DPush` (t, sto)) (SEYes sub) k = case sto of SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) subDescr (des `DPush` (_, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -240,8 +242,13 @@ select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) ---------------------------------- DERIVATIVES --------------------------------- @@ -338,13 +345,27 @@ conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t - -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - (Idx (Select env sto "merge") t) -conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) -conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} @@ -536,6 +557,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = (#acc :++: (#d :++: #shb :++: #tl))) SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + k (storepl `DPush` (t, SDiscr)) + envpro + (SENo prosub) + wf + -- Arrays with "merge" storage are promoted to an accumulator in envPro STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> k (storepl `DPush` (t, SAccum)) @@ -555,18 +583,40 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + -- "merge" values must be an array or fully discrete, so reject everything else. (TODO: generalise this) _ -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - -- where - -- containsTArr :: STy t' -> Bool - -- containsTArr = \case - -- STNil -> False - -- STPair a b -> containsTArr a || containsTArr b - -- STEither a b -> containsTArr a || containsTArr b - -- STArr{} -> True - -- STScal{} -> False - -- STAccum{} -> error "An accumulator in merge storage?" + error $ "Closure variable of 'build'-like thing contains a non-array non-discrete SMerge value: " ++ show t + + -- Discrete values are left as-is, nothing to do + SDiscr -> + k (storepl `DPush` (t, SDiscr)) + envpro + prosub + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + -- containsTArr :: STy t' -> Bool + -- containsTArr = \case + -- STNil -> False + -- STPair a b -> containsTArr a || containsTArr b + -- STEither a b -> containsTArr a || containsTArr b + -- STArr{} -> True + -- STScal{} -> False + -- STAccum{} -> error "An accumulator in merge storage?" makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e @@ -682,20 +732,27 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left accI -> + Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) - Right tupI -> + Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (select SMerge des)) + (ENil ext) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -929,7 +986,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs new file mode 100644 index 0000000..ff37f9a --- /dev/null +++ b/src/Example/GMM.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE TypeApplications #-} +module Example.GMM where + +import Language + + +type R = TScal TF64 +type I64 = TScal TI64 +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) + +-- N, D, K: integers > 0 +-- alpha, M, Q, L: the active parameters +-- X: inactive data +-- m: integer +-- k1: 1/2 N D log(2 pi) +-- k2: 1/2 gamma^2 +-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) +-- where n' = D + m + 1 +-- +-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. +-- +-- See: +-- - "A benchmark of selected algorithmic differentiation tools on some problems +-- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): +-- 889-906 (2018). +-- +-- +-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. +-- Master thesis at Utrecht University. (Appendix B.1) +-- +-- +gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective = fromNamed $ + lambda #N $ lambda #D $ lambda #K $ + lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ + lambda #X $ lambda #m $ + lambda #k1 $ lambda #k2 $ lambda #k3 $ + body $ + let -- We have: + -- sum (exp (x - max(x))) + -- = sum (exp x / exp (max(x))) + -- = sum (exp x) / exp (max(x)) + -- Hence: + -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) + -- + -- So: + -- d/dxi log (sum (exp x)) + -- = 1/(sum (exp x)) * d/dxi sum (exp x) + -- = 1/(sum (exp x)) * sum (d/dxi exp x) + -- = 1/(sum (exp x)) * exp xi + -- = exp xi / sum (exp x) + -- (by (*)) + -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) + -- = exp (xi - max(x)) / sum (exp (x - max(x))) + logsumexp' = lambda @(TVec R) #vec $ body $ + let_ #m (maximum1i #vec) $ + log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m + -- custom (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) + -- (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ + -- let_ #s (idx0 (sum1i #ex)) $ + -- pair (log #s + #m) + -- (pair #ex #s)) + -- (#tape :-> #d :-> + -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) + -- nil #vec + logsumexp v = inline logsumexp' (SNil .$ v) + + mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ + let_ #hei (snd_ (fst_ (shape #mat))) $ + let_ #wid (snd_ (shape #mat)) $ + build1 #hei $ #i :-> + idx0 (sum1i (build1 #wid $ #j :-> + #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) + m *@ v = inline mulmatvec (SNil .$ m .$ v) + + subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ + build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i + a .- b = inline subvec (SNil .$ a .$ b) + + matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ + build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) + m .! i = inline matrow (SNil .$ m .$ i) + + normsq' = lambda @(TVec R) #vec $ body $ + idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) + normsq v = inline normsq' (SNil .$ v) + + qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ + let_ #n (snd_ (shape #q)) $ + build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> + let_ #i (snd_ (fst_ #idx)) $ + let_ #j (snd_ #idx) $ + if_ (#i .== #j) + (exp (#q ! pair nil #i)) + (if_ (#i .> #j) + (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + 0.0) + qmat q l = inline qmat' (SNil .$ q .$ l) + in let_ #k2arr (unit #k2) $ + - #k1 + + idx0 (sum1i (build1 #N $ #i :-> + logsumexp (build1 #K $ #k :-> + #alpha ! pair nil #k + + idx0 (sum1i (#Q .! #k)) + - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) + - toFloat_ #N * logsumexp #alpha + + idx0 (sum1i (build1 #K $ #k :-> + idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) + - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) + - #k3 diff --git a/test/Main.hs b/test/Main.hs index 2573a32..75ab11a 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -10,6 +10,7 @@ module Main where import Data.Bifunctor +import Data.Int (Int64) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -22,6 +23,7 @@ import AST.Pretty import CHAD.Top import CHAD.Types import qualified Example +import qualified Example.GMM as Example import ForwardAD import Interpreter import Interpreter.Rep @@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) - annotate (ppExpr knownEnv expr) - annotate ppdterm - annotate ppdterm_S - diff ppdterm_S20 (==) ppdterm_S - diff outChad closeIsh outChad_S - diff outPrimal closeIsh outChad_S - diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S - diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S + -- annotate (ppExpr knownEnv expr) + -- annotate ppdterm + -- annotate ppdterm_S + diff ppdterm_S (==) ppdterm_S20 + diff outChad_S closeIsh outChad + diff outChad_S closeIsh outPrimal + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD + diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] @@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD" lay2 <- genLayer n1 n2 lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + + ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do + -- The input ranges here are completely arbitrary. + let tR = STScal STF64 + kN <- Gen.integral (Range.linear 1 10) + kD <- Gen.integral (Range.linear 1 10) + kK <- Gen.integral (Range.linear 1 10) + let i2i64 = fromIntegral @Int @Int64 + valpha <- genArray tR (ShNil `ShCons` kK) + vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) + vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vm <- Gen.integral (Range.linear 0 5) + let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil)) ] main :: IO () -- cgit v1.2.3-70-g09d2