summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Tom Smeding <tom@tomsmeding.com> 2024-11-10 12:39:08 +0100
committer Tom Smeding <tom@tomsmeding.com> 2024-11-10 12:39:08 +0100
commit 013e01e28aba090c065ed584671a65aa339ea51b (patch)
tree 1595a8363fc181a13d41224e206d051d4e6a906b
parent 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff)
Test GMM; it fails
-rw-r--r-- chad-fast.cabal | 2
-rw-r--r-- src/AST/Pretty.hs | 6
-rw-r--r-- src/CHAD.hs | 101
-rw-r--r-- src/Example/GMM.hs (renamed from bench/Bench/GMM.hs) | 37
-rw-r--r-- test/Main.hs | 41
5 files changed, 136 insertions, 51 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index cdfc1b1..274e497 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -26,6 +26,7 @@ library
Data
Example
Example.Format
+ Example.GMM
ForwardAD
ForwardAD.DualNumbers
ForwardAD.DualNumbers.Types
@@ -74,7 +75,6 @@ benchmark bench
type: exitcode-stdio-1.0
main-is: Main.hs
other-modules:
- Bench.GMM
build-depends:
chad-fast,
base,
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 51d89dc..ec8574f 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -197,9 +197,9 @@ ppExpr' d val = \case
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
return $ showParen (d > 10) $ showString "custom "
- . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") "
- . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") "
- . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") "
+ . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") "
+ . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") "
+ . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") "
. e1' . showString " "
. e2'
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 59d61a7..8b9f17a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -28,7 +28,6 @@ module CHAD (
Select,
) where
-import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
import GHC.Stack (HasCallStack)
@@ -198,6 +197,7 @@ type Storage :: Symbol -> Type
data Storage s where
SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
SMerge :: Storage "merge" -- ^ just return and merge
+ SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
deriving instance Show (Storage s)
-- | Environment description
@@ -224,11 +224,13 @@ subDescr (des `DPush` (t, sto)) (SEYes sub) k =
case sto of
SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+ SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
subDescr (des `DPush` (_, sto)) (SENo sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+ SDiscr -> k des' submerge subaccum (SENo subd1e)
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
@@ -240,8 +242,13 @@ select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
select _ DTop = SNil
select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SDiscr (DPush des (_, SAccum)) = select s des
select s@SAccum (DPush des (_, SMerge)) = select s des
select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+select s@SDiscr (DPush des (_, SMerge)) = select s des
+select s@SAccum (DPush des (_, SDiscr)) = select s des
+select s@SMerge (DPush des (_, SDiscr)) = select s des
+select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
---------------------------------- DERIVATIVES ---------------------------------
@@ -338,13 +345,27 @@ conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
-conv2Idx :: Descr env sto -> Idx env t
- -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
- (Idx (Select env sto "merge") t)
-conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ
-conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ
-conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
-conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
+data Idx2 env sto t
+ = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ | Idx2Me (Idx (Select env sto "merge") t)
+ | Idx2Di (Idx (Select env sto "discr") t)
+
+conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
+conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ
+conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ
+conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ
+conv2Idx (DPush des (_, SAccum)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, SMerge)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me (IS j)
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, SDiscr)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}
@@ -536,6 +557,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
(#acc :++: (#d :++: #shb :++: #tl)))
SMerge -> case t of
+ -- Discrete values are left as-is
+ _ | isDiscrete t ->
+ k (storepl `DPush` (t, SDiscr))
+ envpro
+ (SENo prosub)
+ wf
+
-- Arrays with "merge" storage are promoted to an accumulator in envPro
STArr (arrn :: SNat arrn) (arrt :: STy arrt) ->
k (storepl `DPush` (t, SAccum))
@@ -555,18 +583,40 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
.> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
- -- "merge" values must be an array, so reject everything else. (TODO: generalise this)
+ -- "merge" values must be an array or fully discrete, so reject everything else. (TODO: generalise this)
_ ->
- error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
- -- where
- -- containsTArr :: STy t' -> Bool
- -- containsTArr = \case
- -- STNil -> False
- -- STPair a b -> containsTArr a || containsTArr b
- -- STEither a b -> containsTArr a || containsTArr b
- -- STArr{} -> True
- -- STScal{} -> False
- -- STAccum{} -> error "An accumulator in merge storage?"
+ error $ "Closure variable of 'build'-like thing contains a non-array non-discrete SMerge value: " ++ show t
+
+ -- Discrete values are left as-is, nothing to do
+ SDiscr ->
+ k (storepl `DPush` (t, SDiscr))
+ envpro
+ prosub
+ wf
+ where
+ isDiscrete :: STy t' -> Bool
+ isDiscrete = \case
+ STNil -> True
+ STPair a b -> isDiscrete a && isDiscrete b
+ STEither a b -> isDiscrete a && isDiscrete b
+ STMaybe a -> isDiscrete a
+ STArr _ a -> isDiscrete a
+ STScal st -> case st of
+ STI32 -> True
+ STI64 -> True
+ STF32 -> False
+ STF64 -> False
+ STBool -> True
+ STAccum{} -> False
+
+ -- containsTArr :: STy t' -> Bool
+ -- containsTArr = \case
+ -- STNil -> False
+ -- STPair a b -> containsTArr a || containsTArr b
+ -- STEither a b -> containsTArr a || containsTArr b
+ -- STArr{} -> True
+ -- STScal{} -> False
+ -- STAccum{} -> error "An accumulator in merge storage?"
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
@@ -682,20 +732,27 @@ drev :: forall env sto t.
drev des = \case
EVar _ t i ->
case conv2Idx des i of
- Left accI ->
+ Idx2Ac accI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
(EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
- Right tupI ->
+ Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (select SMerge des) tupI)
(EPair ext (ENil ext) (EVar ext (d2 t) IZ))
+ Idx2Di _ ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvNone (select SMerge des))
+ (ENil ext)
+
ELet _ rhs body
| Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
<- drev des rhs
@@ -929,7 +986,7 @@ drev des = \case
let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
let collectexpr = bindingsCollect e0 subtapeE in
diff --git a/bench/Bench/GMM.hs b/src/Example/GMM.hs
index 9b84d23..ff37f9a 100644
--- a/bench/Bench/GMM.hs
+++ b/src/Example/GMM.hs
@@ -1,7 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeApplications #-}
-module Bench.GMM where
+module Example.GMM where
import Language
@@ -32,8 +32,8 @@ type TMat = TArr (S (S Z))
-- Master thesis at Utrecht University. (Appendix B.1)
-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
-- <https://tomsmeding.com/f/master.pdf>
-objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
-objective = fromNamed $
+gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+gmmObjective = fromNamed $
lambda #N $ lambda #D $ lambda #K $
lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
lambda #X $ lambda #m $
@@ -56,18 +56,20 @@ objective = fromNamed $
-- = exp xi / (sum (exp (x - max(x))) * exp (max(x)))
-- = exp (xi - max(x)) / sum (exp (x - max(x)))
logsumexp' = lambda @(TVec R) #vec $ body $
- custom (#_ :-> #v :->
- let_ #m (idx0 (maximum1i #v)) $
- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
- (#_ :-> #v :->
- let_ #m (idx0 (maximum1i #v)) $
- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
- let_ #s (idx0 (sum1i #ex)) $
- pair (log #s + #m)
- (pair #ex #s))
- (#tape :-> #d :->
- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
- nil #vec
+ let_ #m (maximum1i #vec) $
+ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
+ -- custom (#_ :-> #v :->
+ -- let_ #m (idx0 (maximum1i #v)) $
+ -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m)
+ -- (#_ :-> #v :->
+ -- let_ #m (idx0 (maximum1i #v)) $
+ -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $
+ -- let_ #s (idx0 (sum1i #ex)) $
+ -- pair (log #s + #m)
+ -- (pair #ex #s))
+ -- (#tape :-> #d :->
+ -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape))
+ -- nil #vec
logsumexp v = inline logsumexp' (SNil .$ v)
mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $
@@ -101,7 +103,8 @@ objective = fromNamed $
(toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j)
0.0)
qmat q l = inline qmat' (SNil .$ q .$ l)
- in - #k1
+ in let_ #k2arr (unit #k2) $
+ - #k1
+ idx0 (sum1i (build1 #N $ #i :->
logsumexp (build1 #K $ #k :->
#alpha ! pair nil #k
@@ -109,6 +112,6 @@ objective = fromNamed $
- 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k))))))
- toFloat_ #N * logsumexp #alpha
+ idx0 (sum1i (build1 #K $ #k :->
- #k2 * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
+ idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k))
- toFloat_ #m * idx0 (sum1i (#Q .! #k))))
- #k3
diff --git a/test/Main.hs b/test/Main.hs
index 2573a32..75ab11a 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -10,6 +10,7 @@
module Main where
import Data.Bifunctor
+import Data.Int (Int64)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
@@ -22,6 +23,7 @@ import AST.Pretty
import CHAD.Top
import CHAD.Types
import qualified Example
+import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
@@ -150,14 +152,14 @@ adTestGen expr envGenerator = property $ do
scCHAD = envScalars env gradCHAD
scCHAD_S = envScalars env gradCHAD_S
annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
- annotate (ppExpr knownEnv expr)
- annotate ppdterm
- annotate ppdterm_S
- diff ppdterm_S20 (==) ppdterm_S
- diff outChad closeIsh outChad_S
- diff outPrimal closeIsh outChad_S
- diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
- diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S
+ -- annotate (ppExpr knownEnv expr)
+ -- annotate ppdterm
+ -- annotate ppdterm_S
+ diff ppdterm_S (==) ppdterm_S20
+ diff outChad_S closeIsh outChad
+ diff outChad_S closeIsh outPrimal
+ diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD
+ diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
@@ -221,6 +223,29 @@ tests = checkSequential $ Group "AD"
lay2 <- genLayer n1 n2
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+
+ ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do
+ -- The input ranges here are completely arbitrary.
+ let tR = STScal STF64
+ kN <- Gen.integral (Range.linear 1 10)
+ kD <- Gen.integral (Range.linear 1 10)
+ kK <- Gen.integral (Range.linear 1 10)
+ let i2i64 = fromIntegral @Int @Int64
+ valpha <- genArray tR (ShNil `ShCons` kK)
+ vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
+ vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vm <- Gen.integral (Range.linear 0 5)
+ let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil))
]
main :: IO ()