diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs index d3e55b3..3a598c0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -132,7 +132,10 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest = flip adTestGen (genEnv (knownEnv @env)) +adTest = adTestCon (const True) + +adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property +adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -198,10 +201,13 @@ tests = checkSequential $ Group "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) - -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ - -- idx0 $ sum1i . sum1i $ - -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> - -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ maximum1i #x) + + ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ minimum1i #x) ,("neural", adTestGen Example.neural $ do let tR = STScal STF64 |