diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 | 
| commit | e7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (patch) | |
| tree | 4dc880e6956b42f0920382d772b49adc2a4ce556 /test/Main.hs | |
| parent | 246439502b78c4a8fcc27ab3296c67471a2b239d (diff) | |
Tests
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 122 | 
1 files changed, 82 insertions, 40 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 986c8a0..d90d9cd 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,20 +1,21 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} +-- {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE OverloadedStrings #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} +-- {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  module Main where  import Data.Bifunctor -import qualified Data.Dependent.Map as DMap -import Data.Dependent.Map (DMap) -import Data.List (intercalate) +-- import qualified Data.Dependent.Map as DMap +-- import Data.Dependent.Map (DMap) +import Data.Foldable (toList) +import Data.List (intercalate, intersperse)  import Hedgehog  import qualified Hedgehog.Gen as Gen  import qualified Hedgehog.Range as Range @@ -52,7 +53,7 @@ gradientByCHAD = \env term input ->            dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)            input1 = toPrimalE env input            (_out, grad) = interpretOpen input1 dterm -      in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad) +      in unTup vUnpair (d2e env) (Value grad)    where      makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')      makeMergeDescr SNil = DTop @@ -127,17 +128,18 @@ genShape = \n -> do      shapeDiv ShNil _ = ShNil      shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) +genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) +genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) +  genValue :: STy a -> Gen (Value a)  genValue = \case    STNil -> return (Value ()) -  STPair a b -> lv2 (,) <$> genValue a <*> genValue b -  STEither a b -> Gen.choice [lv1 Left <$> genValue a -                             ,lv1 Right <$> genValue b] +  STPair a b -> liftV2 (,) <$> genValue a <*> genValue b +  STEither a b -> Gen.choice [liftV Left <$> genValue a +                             ,liftV Right <$> genValue b]    STMaybe t -> Gen.choice [return (Value Nothing) -                          ,lv1 Just <$> genValue t] -  STArr n t -> do -    sh <- genShape n -    Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t) +                          ,liftV Just <$> genValue t] +  STArr n t -> genShape n >>= genArray t    STScal sty -> case sty of      STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)      STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -145,39 +147,33 @@ genValue = \case      STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)      STBool -> Gen.choice [return (Value False), return (Value True)]    STAccum{} -> error "Cannot generate inputs for accumulators" -  where -    lv1 :: (Rep a -> Rep b) -> Value a -> Value b -    lv1 f (Value x) = Value (f x) - -    lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c -    lv2 f (Value x) (Value y) = Value (f x y)  genEnv :: SList STy env -> Gen (SList Value env)  genEnv SNil = return SNil  genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env -data TemplateVar n = TemplateVar (SNat n) String -  deriving (Show) +-- data TemplateVar n = TemplateVar (SNat n) String +--   deriving (Show) -data Template t where -  TpShape :: TemplateVar n -> STy t -> Template (TArr n t) -  TpAny :: STy t -> Template t -  TpPair :: Template a -> Template b -> Template (TPair a b) -deriving instance Show (Template t) +-- data Template t where +--   TpShape :: TemplateVar n -> STy t -> Template (TArr n t) +--   TpAny :: STy t -> Template t +--   TpPair :: Template a -> Template b -> Template (TPair a b) +-- deriving instance Show (Template t) -data ShapeConstraint n = ShapeAtLeast (Shape n) -  deriving (Show) +-- data ShapeConstraint n = ShapeAtLeast (Shape n) +--   deriving (Show) -genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) -genTemplate = _ +-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) +-- genTemplate = _ -genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) -genEnvTemplateExact shapes env = _ +-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) +-- genEnvTemplateExact shapes env = _ -genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) -genEnvTemplate constrs env = do -  shapes <- DMap.traverseWithKey _ constrs -  genEnvTemplateExact shapes env +-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) +-- genEnvTemplate constrs env = do +--   shapes <- DMap.traverseWithKey _ constrs +--   genEnvTemplateExact shapes env  showValue :: Int -> STy t -> Rep t -> ShowS  showValue _ STNil () = showString "()" @@ -186,7 +182,11 @@ showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " .  showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y  showValue _ (STMaybe _) Nothing = showString "Nothing"  showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr)  -- TODO: improve +showValue d (STArr _ t) arr = showParen (d > 10) $ +  showString "arrayFromList " . showsPrec 11 (arrayShape arr) +  . showString " [" +  . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) +  . showString "]"  showValue _ (STScal sty) x = case sty of    STF32 -> shows x    STF64 -> shows x @@ -203,9 +203,18 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"      showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs  adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest expr = property $ do +adTest = flip adTestGen (genEnv (knownEnv @env)) + +-- adTestTp :: forall env. KnownEnv env +--          => DMap TemplateVar ShapeConstraint -> SList Template env +--          -> Ex env (TScal TF64) -> Property +-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp) + +adTestGen :: forall env. KnownEnv env +          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property +adTestGen expr envGenerator = property $ do    let env = knownEnv @env -  input <- forAllWith (showEnv env) $ genEnv env +  input <- forAllWith (showEnv env) envGenerator    let gradFwd = gradientByForward knownEnv expr input        gradCHAD = gradientByCHAD' knownEnv expr input        scFwd = envScalars env gradFwd @@ -219,7 +228,40 @@ adTest expr = property $ do  tests :: IO Bool  tests = checkParallel $ Group "AD"    [("id", adTest $ fromNamed $ lambda #x $ body $ #x) -  ,("neural", adTest Example.neural)] + +  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)) + +  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ +      idx0 $ +        build SZ (shape #x) $ #idx :-> #x ! #idx) + +  ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $ +      idx0 $ sum1i $ +        build (SS SZ) (shape #x) $ #idx :-> #x ! #idx) + +  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ +      idx0 $ sum1i . sum1i $ +        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) + +  -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ +  --     idx0 $ sum1i . sum1i $ +  --       build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> +  --         oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + +  -- ,("neural", adTestGen Example.neural $ do +  --     let tR = STScal STF64 +  --     let genLayer nin nout = +  --           liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) +  --                      <*> genArray tR (ShNil `ShCons` nout) +  --     nin <- Gen.integral (Range.linear 1 10) +  --     n1 <- Gen.integral (Range.linear 1 10) +  --     n2 <- Gen.integral (Range.linear 1 10) +  --     input <- genArray tR (ShNil `ShCons` nin) +  --     lay1 <- genLayer nin n1 +  --     lay2 <- genLayer n1 n2 +  --     lay3 <- genArray tR (ShNil `ShCons` n2) +  --     return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) +  ]  main :: IO ()  main = defaultMain [tests] | 
