diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 |
commit | a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (patch) | |
tree | 1f00fa82540f4a54ddbf45fc6e5717b6dd8d5f94 | |
parent | 4d573fa32997a8e4824bf8326fb675d0c195b1ac (diff) |
Test gmm
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/Example/GMM.hs | 15 | ||||
-rw-r--r-- | src/ForwardAD.hs | 9 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 10 | ||||
-rw-r--r-- | test/Main.hs | 140 |
5 files changed, 140 insertions, 36 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 274e497..8817718 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -65,8 +65,10 @@ test-suite test build-depends: chad-fast, base, + containers, dependent-map, hedgehog, + transformers, hs-source-dirs: test default-language: Haskell2010 ghc-options: -Wall -threaded diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index ff37f9a..1db88bd 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -32,8 +32,16 @@ type TMat = TArr (S (S Z)) -- Master thesis at Utrecht University. (Appendix B.1) -- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y> -- <https://tomsmeding.com/f/master.pdf> -gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective = fromNamed $ +-- +-- The 'wrong' argument, when set to True, changes the objective function to +-- one with a bug that makes a certain `build` result unused. This +-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's +-- dense, even though it may be a zero (i.e. empty). The "unused" test in +-- test/Main.hs tries to isolate this test, but the wrong version of +-- gmmObjective is here to check (after that bug is fixed) whether it really +-- fixes the original bug. +gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective wrong = fromNamed $ lambda #N $ lambda #D $ lambda #K $ lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ lambda #X $ lambda #m $ @@ -100,7 +108,8 @@ gmmObjective = fromNamed $ if_ (#i .== #j) (exp (#q ! pair nil #i)) (if_ (#i .> #j) - (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) + else #l ! 
pair nil (#i * (#i - 1) `idiv` 2 + #j)) 0.0) qmat q l = inline qmat' (SNil .$ q .$ l) in let_ #k2arr (unit #k2) $ diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 67d22dd..b95385c 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,11 +7,12 @@ module ForwardAD where import Data.Bifunctor (bimap) --- import Data.Foldable (toList) + +-- import Debug.Trace +-- import AST.Pretty import Array import AST --- import AST.Bindings import Data import ForwardAD.DualNumbers import Interpreter @@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwd env expr input dres = let outty = typeOf expr - in dnOnehotEnvs env input $ \dnInput -> + in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) in dotprodTan outty outtan dres diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 7ef9088..0007991 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where -import Data.List (intersperse) +import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef import GHC.TypeError @@ -11,6 +11,7 @@ import GHC.TypeError import Array import AST import AST.Pretty +import Data type family Rep t where @@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of STI64 -> shows x STBool -> shows x showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x 
"" : showEntries env xs diff --git a/test/Main.hs b/test/Main.hs index 75ab11a..72b7809 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -7,11 +7,15 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) -import Data.List (intercalate) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range @@ -86,33 +90,89 @@ closeIsh :: Double -> Double -> Bool closeIsh a b = abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) -genShape :: SNat n -> Gen (Shape n) -genShape = \n -> do - sh <- genShapeNaive n +data a :$ b = a :$ b deriving (Show) ; infixl :$ + +-- An empty name means "no restrictions". +data TplConstr = C String -- ^ name; @""@ means anonymous + Int -- ^ minimum value to generate + +type family DimNames n where + DimNames Z = () + DimNames (S Z) = TplConstr + DimNames (S n) = DimNames n :$ TplConstr + +type family Tpl t where + Tpl (TArr n t) = DimNames n + Tpl (TPair a b) = (Tpl a, Tpl b) + -- If you add equations here, don't forget to update genValue! It currently + -- just emptyTpl's things out. 
+ Tpl _ = () + +data a :& b = a :& b deriving (Show) ; infixl :& + +type family TemplateE env where + TemplateE '[] = () + TemplateE '[t] = Tpl t + TemplateE (t : ts) = TemplateE ts :& Tpl t + +emptyDimNames :: SNat n -> DimNames n +emptyDimNames SZ = () +emptyDimNames (SS SZ) = C "" 0 +emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0 + +emptyTpl :: STy t -> Tpl t +emptyTpl (STArr n _) = emptyDimNames n +emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b) +emptyTpl (STScal _) = () +emptyTpl _ = error "too lazy" + +emptyTemplateE :: SList STy env -> TemplateE env +emptyTemplateE SNil = () +emptyTemplateE (t `SCons` SNil) = emptyTpl t +emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t + +genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) +genShape = \n tpl -> do + sh <- genShapeNaive n tpl let sz = shapeSize sh factor = sz `div` 100 + 1 return (shapeDiv sh factor) where - genShapeNaive :: SNat n -> Gen (Shape n) - genShapeNaive SZ = return ShNil - genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10) + genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) + genShapeNaive SZ () = return ShNil + genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name + genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name + + genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int + genNamedDim (C "" lo) = genDim lo + genNamedDim (C name lo) = gets (Map.lookup name) >>= \case + Nothing -> do + dim <- genDim lo + modify (Map.insert name dim) + return dim + Just dim -> return dim + + genDim :: Int -> StateT (Map String Int) Gen Int + genDim lo = Gen.integral (Range.linear lo 10) shapeDiv :: Shape n -> Int -> Shape n shapeDiv ShNil _ = ShNil shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) -genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) 
+genArray t sh = + Value <$> arrayGenerateLinM sh (\_ -> + unValue <$> evalStateT (genValue t (emptyTpl t)) mempty) -genValue :: STy a -> Gen (Value a) -genValue = \case +genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) +genValue topty tpl = case topty of STNil -> return (Value ()) - STPair a b -> liftV2 (,) <$> genValue a <*> genValue b - STEither a b -> Gen.choice [liftV Left <$> genValue a - ,liftV Right <$> genValue b] + STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) + STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a) + ,liftV Right <$> genValue b (emptyTpl b)] STMaybe t -> Gen.choice [return (Value Nothing) - ,liftV Just <$> genValue t] - STArr n t -> genShape n >>= genArray t + ,liftV Just <$> genValue t (emptyTpl t)] + STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -121,22 +181,22 @@ genValue = \case STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" -genEnv :: SList STy env -> Gen (SList Value env) -genEnv SNil = return SNil -genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env - -showEnv :: SList STy env -> SList Value env -> String -showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" - where - showEntries :: SList STy env -> SList Value env -> [String] - showEntries SNil SNil = [] - showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs +genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env) +genEnv SNil () = return SNil +genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil +genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl adTest :: forall env. 
KnownEnv env => Ex env (TScal TF64) -> Property adTest = adTestCon (const True) adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property -adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) +adTestCon constr term = + let env = knownEnv + in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) + +adTestTp :: forall env. KnownEnv env + => TemplateE env -> Ex env (TScal TF64) -> Property +adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty) adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -210,6 +270,10 @@ tests = checkSequential $ Group "AD" fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ idx0 $ sum1i $ minimum1i #x) + ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ + let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ + 42) + ,("neural", adTestGen Example.neural $ do let tR = STScal STF64 let genLayer nin nout = @@ -224,7 +288,26 @@ tests = checkSequential $ Group "AD" lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) - ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do + ,("logsumexp", adTestTp (C "" 1) $ + fromNamed $ lambda @(TArr N1 _) #vec $ body $ + let_ #m (maximum1i #vec) $ + log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m) + + ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $ + fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ + idx0 $ sum1i $ + let_ #hei (snd_ (fst_ (shape #mat))) $ + let_ #wid (snd_ (shape #mat)) $ + build1 #hei $ #i :-> + idx0 (sum1i (build1 #wid $ #j :-> + #mat ! pair (pair nil #i) #j * #vec ! 
pair nil #j))) + + ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + + ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) + ] + where + genGMM = do -- The input ranges here are completely arbitrary. let tR = STScal STF64 kN <- Gen.integral (Range.linear 1 10) @@ -245,8 +328,7 @@ tests = checkSequential $ Group "AD" Value vm `SCons` vX `SCons` vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` - SNil)) - ] + SNil) main :: IO () main = defaultMain [tests] |