summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs16
1 files changed, 11 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs
index d3e55b3..3a598c0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -132,7 +132,10 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
-adTest = flip adTestGen (genEnv (knownEnv @env))
+adTest = adTestCon (const True)
+
+adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
+adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env)))
adTestGen :: forall env. KnownEnv env
=> Ex env (TScal TF64) -> Gen (SList Value env) -> Property
@@ -198,10 +201,13 @@ tests = checkSequential $ Group "AD"
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)
- -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
- -- idx0 $ sum1i . sum1i $
- -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
- -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)
+ ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ idx0 $ sum1i $ maximum1i #x)
+
+ ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ idx0 $ sum1i $ minimum1i #x)
,("neural", adTestGen Example.neural $ do
let tR = STScal STF64