diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-21 23:20:57 +0200 |
commit | e7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (patch) | |
tree | 4dc880e6956b42f0920382d772b49adc2a4ce556 /test | |
parent | 246439502b78c4a8fcc27ab3296c67471a2b239d (diff) |
Tests
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 122 |
1 files changed, 82 insertions, 40 deletions
diff --git a/test/Main.hs b/test/Main.hs index 986c8a0..d90d9cd 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,20 +1,21 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} +-- {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} +-- {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Main where import Data.Bifunctor -import qualified Data.Dependent.Map as DMap -import Data.Dependent.Map (DMap) -import Data.List (intercalate) +-- import qualified Data.Dependent.Map as DMap +-- import Data.Dependent.Map (DMap) +import Data.Foldable (toList) +import Data.List (intercalate, intersperse) import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range @@ -52,7 +53,7 @@ gradientByCHAD = \env term input -> dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) input1 = toPrimalE env input (_out, grad) = interpretOpen input1 dterm - in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad) + in unTup vUnpair (d2e env) (Value grad) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop @@ -127,17 +128,18 @@ genShape = \n -> do shapeDiv ShNil _ = ShNil shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) +genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) +genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) + genValue :: STy a -> Gen (Value a) genValue = \case STNil -> return (Value ()) - STPair a b -> lv2 (,) <$> genValue a <*> genValue b - STEither a b -> Gen.choice [lv1 Left <$> genValue a - ,lv1 Right <$> genValue b] + STPair a b -> liftV2 (,) <$> genValue a <*> genValue b + STEither a b -> Gen.choice [liftV Left <$> genValue a + ,liftV Right <$> genValue b] STMaybe t -> Gen.choice [return (Value Nothing) - ,lv1 Just <$> genValue t] - STArr n t -> do - sh <- genShape n - Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t) + ,liftV Just <$> genValue t] + STArr n t -> genShape n >>= genArray t STScal sty -> case sty of STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -145,39 +147,33 @@ genValue = \case STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" - where - lv1 :: (Rep a -> Rep b) -> Value a -> Value b - lv1 f (Value x) = Value (f x) - - lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c - lv2 f (Value x) (Value y) = Value (f x y) genEnv :: SList STy env -> Gen (SList Value env) genEnv SNil = return SNil genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env -data TemplateVar n = TemplateVar (SNat n) String - deriving (Show) +-- data TemplateVar n = TemplateVar (SNat n) String +-- deriving (Show) -data Template t where - TpShape :: TemplateVar n -> STy t -> Template (TArr n t) - TpAny :: STy t -> Template t - TpPair :: Template a -> Template b -> Template (TPair a b) -deriving instance Show (Template t) +-- data Template t where +-- TpShape :: TemplateVar n -> STy t -> Template (TArr n t) +-- TpAny :: STy t -> Template t +-- TpPair :: Template a -> Template b -> Template (TPair a b) +-- deriving instance Show (Template t) -data ShapeConstraint n = ShapeAtLeast (Shape n) - deriving (Show) +-- data ShapeConstraint n = ShapeAtLeast (Shape n) +-- deriving (Show) -genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) -genTemplate = _ +-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) +-- genTemplate = _ -genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) -genEnvTemplateExact shapes env = _ +-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) +-- genEnvTemplateExact shapes env = _ -genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) -genEnvTemplate constrs env = do - shapes <- DMap.traverseWithKey _ constrs - genEnvTemplateExact shapes env +-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) +-- genEnvTemplate constrs env = do +-- shapes <- DMap.traverseWithKey _ constrs +-- genEnvTemplateExact shapes env showValue :: Int -> STy t -> Rep t -> ShowS showValue _ STNil () = showString "()" @@ -186,7 +182,11 @@ showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" showValue _ (STScal sty) x = case sty of STF32 -> shows x STF64 -> shows x @@ -203,9 +203,18 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest expr = property $ do +adTest = flip adTestGen (genEnv (knownEnv @env)) + +-- adTestTp :: forall env. KnownEnv env +-- => DMap TemplateVar ShapeConstraint -> SList Template env +-- -> Ex env (TScal TF64) -> Property +-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp) + +adTestGen :: forall env. KnownEnv env + => Ex env (TScal TF64) -> Gen (SList Value env) -> Property +adTestGen expr envGenerator = property $ do let env = knownEnv @env - input <- forAllWith (showEnv env) $ genEnv env + input <- forAllWith (showEnv env) envGenerator let gradFwd = gradientByForward knownEnv expr input gradCHAD = gradientByCHAD' knownEnv expr input scFwd = envScalars env gradFwd @@ -219,7 +228,40 @@ adTest expr = property $ do tests :: IO Bool tests = checkParallel $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) - ,("neural", adTest Example.neural)] + + ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)) + + ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx) + + ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $ + idx0 $ sum1i $ + build (SS SZ) (shape #x) $ #idx :-> #x ! #idx) + + ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ + idx0 $ sum1i . sum1i $ + build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) + + -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ + -- idx0 $ sum1i . sum1i $ + -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> + -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + + -- ,("neural", adTestGen Example.neural $ do + -- let tR = STScal STF64 + -- let genLayer nin nout = + -- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + -- <*> genArray tR (ShNil `ShCons` nout) + -- nin <- Gen.integral (Range.linear 1 10) + -- n1 <- Gen.integral (Range.linear 1 10) + -- n2 <- Gen.integral (Range.linear 1 10) + -- input <- genArray tR (ShNil `ShCons` nin) + -- lay1 <- genLayer nin n1 + -- lay2 <- genLayer n1 n2 + -- lay3 <- genArray tR (ShNil `ShCons` n2) + -- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + ] main :: IO () main = defaultMain [tests] |