diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 | 
| commit | a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (patch) | |
| tree | 1f00fa82540f4a54ddbf45fc6e5717b6dd8d5f94 /test | |
| parent | 4d573fa32997a8e4824bf8326fb675d0c195b1ac (diff) | |
Test gmm
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 140 | 
1 files changed, 111 insertions, 29 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 75ab11a..72b7809 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -7,11 +7,15 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-}  module Main where +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State  import Data.Bifunctor  import Data.Int (Int64) -import Data.List (intercalate) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map  import Hedgehog  import qualified Hedgehog.Gen as Gen  import qualified Hedgehog.Range as Range @@ -86,33 +90,89 @@ closeIsh :: Double -> Double -> Bool  closeIsh a b =    abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) -genShape :: SNat n -> Gen (Shape n) -genShape = \n -> do -  sh <- genShapeNaive n +data a :$ b = a :$ b deriving (Show) ; infixl :$ + +-- An empty name means "no restrictions". +data TplConstr = C String  -- ^ name; @""@ means anonymous +                   Int     -- ^ minimum value to generate + +type family DimNames n where +  DimNames Z = () +  DimNames (S Z) = TplConstr +  DimNames (S n) = DimNames n :$ TplConstr + +type family Tpl t where +  Tpl (TArr n t) = DimNames n +  Tpl (TPair a b) = (Tpl a, Tpl b) +  -- If you add equations here, don't forget to update genValue! It currently +  -- just emptyTpl's things out. +  Tpl _ = () + +data a :& b = a :& b deriving (Show) ; infixl :& + +type family TemplateE env where +  TemplateE '[] = () +  TemplateE '[t] = Tpl t +  TemplateE (t : ts) = TemplateE ts :& Tpl t + +emptyDimNames :: SNat n -> DimNames n +emptyDimNames SZ = () +emptyDimNames (SS SZ) = C "" 0 +emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0 + +emptyTpl :: STy t -> Tpl t +emptyTpl (STArr n _) = emptyDimNames n +emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b) +emptyTpl (STScal _) = () +emptyTpl _ = error "too lazy" + +emptyTemplateE :: SList STy env -> TemplateE env +emptyTemplateE SNil = () +emptyTemplateE (t `SCons` SNil) = emptyTpl t +emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t + +genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) +genShape = \n tpl -> do +  sh <- genShapeNaive n tpl    let sz = shapeSize sh        factor = sz `div` 100 + 1    return (shapeDiv sh factor)    where -    genShapeNaive :: SNat n -> Gen (Shape n) -    genShapeNaive SZ = return ShNil -    genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10) +    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) +    genShapeNaive SZ () = return ShNil +    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name +    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name + +    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int +    genNamedDim (C "" lo) = genDim lo +    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case +      Nothing -> do +        dim <- genDim lo +        modify (Map.insert name dim) +        return dim +      Just dim -> return dim + +    genDim :: Int -> StateT (Map String Int) Gen Int +    genDim lo = Gen.integral (Range.linear lo 10)      shapeDiv :: Shape n -> Int -> Shape n      shapeDiv ShNil _ = ShNil      shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)  genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) -genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) +genArray t sh = +  Value <$> arrayGenerateLinM sh (\_ -> +              unValue <$> evalStateT (genValue t (emptyTpl t)) mempty) -genValue :: STy a -> Gen (Value a) -genValue = \case +genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) +genValue topty tpl = case topty of    STNil -> return (Value ()) -  STPair a b -> liftV2 (,) <$> genValue a <*> genValue b -  STEither a b -> Gen.choice [liftV Left <$> genValue a -                             ,liftV Right <$> genValue b] +  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) +  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a) +                             ,liftV Right <$> genValue b (emptyTpl b)]    STMaybe t -> Gen.choice [return (Value Nothing) -                          ,liftV Just <$> genValue t] -  STArr n t -> genShape n >>= genArray t +                          ,liftV Just <$> genValue t (emptyTpl t)] +  STArr n t -> genShape n tpl >>= lift . genArray t    STScal sty -> case sty of      STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)      STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -121,22 +181,22 @@ genValue = \case      STBool -> Gen.choice [return (Value False), return (Value True)]    STAccum{} -> error "Cannot generate inputs for accumulators" -genEnv :: SList STy env -> Gen (SList Value env) -genEnv SNil = return SNil -genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env - -showEnv :: SList STy env -> SList Value env -> String -showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" -  where -    showEntries :: SList STy env -> SList Value env -> [String] -    showEntries SNil SNil = [] -    showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs +genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env) +genEnv SNil () = return SNil +genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil +genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl  adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property  adTest = adTestCon (const True)  adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property -adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) +adTestCon constr term = +  let env = knownEnv +  in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) + +adTestTp :: forall env. KnownEnv env +         => TemplateE env -> Ex env (TScal TF64) -> Property +adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty)  adTestGen :: forall env. KnownEnv env            => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -210,6 +270,10 @@ tests = checkSequential $ Group "AD"        fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $          idx0 $ sum1i $ minimum1i #x) +  ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ +    let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ +      42) +    ,("neural", adTestGen Example.neural $ do        let tR = STScal STF64        let genLayer nin nout = @@ -224,7 +288,26 @@ tests = checkSequential $ Group "AD"        lay3 <- genArray tR (ShNil `ShCons` n2)        return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) -  ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do +  ,("logsumexp", adTestTp (C "" 1) $ +      fromNamed $ lambda @(TArr N1 _) #vec $ body $ +      let_ #m (maximum1i #vec) $ +        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m) + +  ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $ +      fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ +      idx0 $ sum1i $ +        let_ #hei (snd_ (fst_ (shape #mat))) $ +        let_ #wid (snd_ (shape #mat)) $ +          build1 #hei $ #i :-> +            idx0 (sum1i (build1 #wid $ #j :-> +                           #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))) + +  ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + +  ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) +  ] +  where +    genGMM = do        -- The input ranges here are completely arbitrary.        let tR = STScal STF64        kN <- Gen.integral (Range.linear 1 10) @@ -245,8 +328,7 @@ tests = checkSequential $ Group "AD"                Value vm `SCons` vX `SCons`                vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`                Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` -              SNil)) -  ] +              SNil)  main :: IO ()  main = defaultMain [tests] | 
