summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-21 23:20:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-21 23:20:57 +0200
commite7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (patch)
tree4dc880e6956b42f0920382d772b49adc2a4ce556 /test
parent246439502b78c4a8fcc27ab3296c67471a2b239d (diff)
Tests
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs122
1 files changed, 82 insertions, 40 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 986c8a0..d90d9cd 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,20 +1,21 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
+-- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
+-- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Bifunctor
-import qualified Data.Dependent.Map as DMap
-import Data.Dependent.Map (DMap)
-import Data.List (intercalate)
+-- import qualified Data.Dependent.Map as DMap
+-- import Data.Dependent.Map (DMap)
+import Data.Foldable (toList)
+import Data.List (intercalate, intersperse)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
@@ -52,7 +53,7 @@ gradientByCHAD = \env term input ->
dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
input1 = toPrimalE env input
(_out, grad) = interpretOpen input1 dterm
- in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
+ in unTup vUnpair (d2e env) (Value grad)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
@@ -127,17 +128,18 @@ genShape = \n -> do
shapeDiv ShNil _ = ShNil
shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
+genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
+genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t)
+
genValue :: STy a -> Gen (Value a)
genValue = \case
STNil -> return (Value ())
- STPair a b -> lv2 (,) <$> genValue a <*> genValue b
- STEither a b -> Gen.choice [lv1 Left <$> genValue a
- ,lv1 Right <$> genValue b]
+ STPair a b -> liftV2 (,) <$> genValue a <*> genValue b
+ STEither a b -> Gen.choice [liftV Left <$> genValue a
+ ,liftV Right <$> genValue b]
STMaybe t -> Gen.choice [return (Value Nothing)
- ,lv1 Just <$> genValue t]
- STArr n t -> do
- sh <- genShape n
- Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t)
+ ,liftV Just <$> genValue t]
+ STArr n t -> genShape n >>= genArray t
STScal sty -> case sty of
STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
@@ -145,39 +147,33 @@ genValue = \case
STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
- where
- lv1 :: (Rep a -> Rep b) -> Value a -> Value b
- lv1 f (Value x) = Value (f x)
-
- lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
- lv2 f (Value x) (Value y) = Value (f x y)
genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
-data TemplateVar n = TemplateVar (SNat n) String
- deriving (Show)
+-- data TemplateVar n = TemplateVar (SNat n) String
+-- deriving (Show)
-data Template t where
- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
- TpAny :: STy t -> Template t
- TpPair :: Template a -> Template b -> Template (TPair a b)
-deriving instance Show (Template t)
+-- data Template t where
+-- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
+-- TpAny :: STy t -> Template t
+-- TpPair :: Template a -> Template b -> Template (TPair a b)
+-- deriving instance Show (Template t)
-data ShapeConstraint n = ShapeAtLeast (Shape n)
- deriving (Show)
+-- data ShapeConstraint n = ShapeAtLeast (Shape n)
+-- deriving (Show)
-genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
-genTemplate = _
+-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
+-- genTemplate = _
-genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
-genEnvTemplateExact shapes env = _
+-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
+-- genEnvTemplateExact shapes env = _
-genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
-genEnvTemplate constrs env = do
- shapes <- DMap.traverseWithKey _ constrs
- genEnvTemplateExact shapes env
+-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
+-- genEnvTemplate constrs env = do
+-- shapes <- DMap.traverseWithKey _ constrs
+-- genEnvTemplateExact shapes env
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
@@ -186,7 +182,11 @@ showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " .
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
-showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve
+showValue d (STArr _ t) arr = showParen (d > 10) $
+ showString "arrayFromList " . showsPrec 11 (arrayShape arr)
+ . showString " ["
+ . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
+ . showString "]"
showValue _ (STScal sty) x = case sty of
STF32 -> shows x
STF64 -> shows x
@@ -203,9 +203,18 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
-adTest expr = property $ do
+adTest = flip adTestGen (genEnv (knownEnv @env))
+
+-- adTestTp :: forall env. KnownEnv env
+-- => DMap TemplateVar ShapeConstraint -> SList Template env
+-- -> Ex env (TScal TF64) -> Property
+-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)
+
+adTestGen :: forall env. KnownEnv env
+ => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
+adTestGen expr envGenerator = property $ do
let env = knownEnv @env
- input <- forAllWith (showEnv env) $ genEnv env
+ input <- forAllWith (showEnv env) envGenerator
let gradFwd = gradientByForward knownEnv expr input
gradCHAD = gradientByCHAD' knownEnv expr input
scFwd = envScalars env gradFwd
@@ -219,7 +228,40 @@ adTest expr = property $ do
tests :: IO Bool
tests = checkParallel $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
- ,("neural", adTest Example.neural)]
+
+ ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))
+
+ ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx)
+
+ ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $
+ idx0 $ sum1i $
+ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx)
+
+ ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
+ idx0 $ sum1i . sum1i $
+ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)
+
+ -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
+ -- idx0 $ sum1i . sum1i $
+ -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
+ -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)
+
+ -- ,("neural", adTestGen Example.neural $ do
+ -- let tR = STScal STF64
+ -- let genLayer nin nout =
+ -- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ -- <*> genArray tR (ShNil `ShCons` nout)
+ -- nin <- Gen.integral (Range.linear 1 10)
+ -- n1 <- Gen.integral (Range.linear 1 10)
+ -- n2 <- Gen.integral (Range.linear 1 10)
+ -- input <- genArray tR (ShNil `ShCons` nin)
+ -- lay1 <- genLayer nin n1
+ -- lay2 <- genLayer n1 n2
+ -- lay3 <- genArray tR (ShNil `ShCons` n2)
+ -- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+ ]
main :: IO ()
main = defaultMain [tests]