summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Tom Smeding <tom@tomsmeding.com>, 2024-11-10 22:40:54 +0100
committer: Tom Smeding <tom@tomsmeding.com>, 2024-11-10 22:40:54 +0100
commit: a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (patch)
tree: 1f00fa82540f4a54ddbf45fc6e5717b6dd8d5f94
parent: 4d573fa32997a8e4824bf8326fb675d0c195b1ac (diff)
Test gmm
-rw-r--r-- chad-fast.cabal 2
-rw-r--r-- src/Example/GMM.hs 15
-rw-r--r-- src/ForwardAD.hs 9
-rw-r--r-- src/Interpreter/Rep.hs 10
-rw-r--r-- test/Main.hs 140
5 files changed, 140 insertions, 36 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 274e497..8817718 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -65,8 +65,10 @@ test-suite test
build-depends:
chad-fast,
base,
+ containers,
dependent-map,
hedgehog,
+ transformers,
hs-source-dirs: test
default-language: Haskell2010
ghc-options: -Wall -threaded
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index ff37f9a..1db88bd 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -32,8 +32,16 @@ type TMat = TArr (S (S Z))
-- Master thesis at Utrecht University. (Appendix B.1)
-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
-- <https://tomsmeding.com/f/master.pdf>
-gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
-gmmObjective = fromNamed $
+--
+-- The 'wrong' argument, when set to True, changes the objective function to
+-- one with a bug that makes a certain `build` result unused. This makes
+-- the CHAD code fail because it tries to use a D2 (TArr) as if it's
+-- dense, even though it may be a zero (i.e. empty). The "unused" test in
+-- test/Main.hs tries to isolate this failure, but the wrong version of
+-- gmmObjective is here to check (after that bug is fixed) whether the fix
+-- really resolves the original bug.
+gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+gmmObjective wrong = fromNamed $
lambda #N $ lambda #D $ lambda #K $
lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
lambda #X $ lambda #m $
@@ -100,7 +108,8 @@ gmmObjective = fromNamed $
if_ (#i .== #j)
(exp (#q ! pair nil #i))
(if_ (#i .> #j)
- (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j)
+ (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j)
+ else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j))
0.0)
qmat q l = inline qmat' (SNil .$ q .$ l)
in let_ #k2arr (unit #k2) $
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 67d22dd..b95385c 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -7,11 +7,12 @@
module ForwardAD where
import Data.Bifunctor (bimap)
--- import Data.Foldable (toList)
+
+-- import Debug.Trace
+-- import AST.Pretty
import Array
import AST
--- import AST.Bindings
import Data
import ForwardAD.DualNumbers
import Interpreter
@@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd env expr input dres =
let outty = typeOf expr
- in dnOnehotEnvs env input $ \dnInput ->
+ in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
in dotprodTan outty outtan dres
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 7ef9088..0007991 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -3,7 +3,7 @@
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
import GHC.TypeError
@@ -11,6 +11,7 @@ import GHC.TypeError
import Array
import AST
import AST.Pretty
+import Data
type family Rep t where
@@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of
STI64 -> shows x
STBool -> shows x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">"
+
+showEnv :: SList STy env -> SList Value env -> String
+showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
+ where
+ showEntries :: SList STy env -> SList Value env -> [String]
+ showEntries SNil SNil = []
+ showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
diff --git a/test/Main.hs b/test/Main.hs
index 75ab11a..72b7809 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -7,11 +7,15 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module Main where
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
-import Data.List (intercalate)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
@@ -86,33 +90,89 @@ closeIsh :: Double -> Double -> Bool
closeIsh a b =
abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
-genShape :: SNat n -> Gen (Shape n)
-genShape = \n -> do
- sh <- genShapeNaive n
+data a :$ b = a :$ b deriving (Show) ; infixl :$
+
+-- An empty name means the dimension is anonymous: its size is not shared with any other dimension (the minimum still applies).
+data TplConstr = C String -- ^ name; @""@ means anonymous
+ Int -- ^ minimum value to generate
+
+type family DimNames n where
+ DimNames Z = ()
+ DimNames (S Z) = TplConstr
+ DimNames (S n) = DimNames n :$ TplConstr
+
+type family Tpl t where
+ Tpl (TArr n t) = DimNames n
+ Tpl (TPair a b) = (Tpl a, Tpl b)
+ -- If you add equations here, don't forget to update genValue! It currently
+ -- just emptyTpl's things out.
+ Tpl _ = ()
+
+data a :& b = a :& b deriving (Show) ; infixl :&
+
+type family TemplateE env where
+ TemplateE '[] = ()
+ TemplateE '[t] = Tpl t
+ TemplateE (t : ts) = TemplateE ts :& Tpl t
+
+emptyDimNames :: SNat n -> DimNames n
+emptyDimNames SZ = ()
+emptyDimNames (SS SZ) = C "" 0
+emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0
+
+emptyTpl :: STy t -> Tpl t
+emptyTpl (STArr n _) = emptyDimNames n
+emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
+emptyTpl (STScal _) = ()
+emptyTpl _ = error "too lazy"
+
+emptyTemplateE :: SList STy env -> TemplateE env
+emptyTemplateE SNil = ()
+emptyTemplateE (t `SCons` SNil) = emptyTpl t
+emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t
+
+genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
+genShape = \n tpl -> do
+ sh <- genShapeNaive n tpl
let sz = shapeSize sh
factor = sz `div` 100 + 1
return (shapeDiv sh factor)
where
- genShapeNaive :: SNat n -> Gen (Shape n)
- genShapeNaive SZ = return ShNil
- genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10)
+ genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
+ genShapeNaive SZ () = return ShNil
+ genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
+ genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name
+
+ genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
+ genNamedDim (C "" lo) = genDim lo
+ genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
+ Nothing -> do
+ dim <- genDim lo
+ modify (Map.insert name dim)
+ return dim
+ Just dim -> return dim
+
+ genDim :: Int -> StateT (Map String Int) Gen Int
+ genDim lo = Gen.integral (Range.linear lo 10)
shapeDiv :: Shape n -> Int -> Shape n
shapeDiv ShNil _ = ShNil
shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
-genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t)
+genArray t sh =
+ Value <$> arrayGenerateLinM sh (\_ ->
+ unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)
-genValue :: STy a -> Gen (Value a)
-genValue = \case
+genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
+genValue topty tpl = case topty of
STNil -> return (Value ())
- STPair a b -> liftV2 (,) <$> genValue a <*> genValue b
- STEither a b -> Gen.choice [liftV Left <$> genValue a
- ,liftV Right <$> genValue b]
+ STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
+ STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
+ ,liftV Right <$> genValue b (emptyTpl b)]
STMaybe t -> Gen.choice [return (Value Nothing)
- ,liftV Just <$> genValue t]
- STArr n t -> genShape n >>= genArray t
+ ,liftV Just <$> genValue t (emptyTpl t)]
+ STArr n t -> genShape n tpl >>= lift . genArray t
STScal sty -> case sty of
STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
@@ -121,22 +181,22 @@ genValue = \case
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
-genEnv :: SList STy env -> Gen (SList Value env)
-genEnv SNil = return SNil
-genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
-
-showEnv :: SList STy env -> SList Value env -> String
-showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
- where
- showEntries :: SList STy env -> SList Value env -> [String]
- showEntries SNil SNil = []
- showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
+genEnv SNil () = return SNil
+genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
+genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = adTestCon (const True)
adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
-adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env)))
+adTestCon constr term =
+ let env = knownEnv
+ in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
+
+adTestTp :: forall env. KnownEnv env
+ => TemplateE env -> Ex env (TScal TF64) -> Property
+adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty)
adTestGen :: forall env. KnownEnv env
=> Ex env (TScal TF64) -> Gen (SList Value env) -> Property
@@ -210,6 +270,10 @@ tests = checkSequential $ Group "AD"
fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
idx0 $ sum1i $ minimum1i #x)
+ ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
+ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
+ 42)
+
,("neural", adTestGen Example.neural $ do
let tR = STScal STF64
let genLayer nin nout =
@@ -224,7 +288,26 @@ tests = checkSequential $ Group "AD"
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
- ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do
+ ,("logsumexp", adTestTp (C "" 1) $
+ fromNamed $ lambda @(TArr N1 _) #vec $ body $
+ let_ #m (maximum1i #vec) $
+ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m)
+
+ ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $
+ fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
+ idx0 $ sum1i $
+ let_ #hei (snd_ (fst_ (shape #mat))) $
+ let_ #wid (snd_ (shape #mat)) $
+ build1 #hei $ #i :->
+ idx0 (sum1i (build1 #wid $ #j :->
+ #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)))
+
+ ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM)
+
+ ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM)
+ ]
+ where
+ genGMM = do
-- The input ranges here are completely arbitrary.
let tR = STScal STF64
kN <- Gen.integral (Range.linear 1 10)
@@ -245,8 +328,7 @@ tests = checkSequential $ Group "AD"
Value vm `SCons` vX `SCons`
vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
- SNil))
- ]
+ SNil)
main :: IO ()
main = defaultMain [tests]