From 992249ebf159ba3783a9345430013e52294c26aa Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 9 Nov 2024 11:15:06 +0100 Subject: Maximum/minimum --- test/Main.hs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index d3e55b3..3a598c0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -132,7 +132,10 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest = flip adTestGen (genEnv (knownEnv @env)) +adTest = adTestCon (const True) + +adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property +adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -198,10 +201,13 @@ tests = checkSequential $ Group "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) - -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ - -- idx0 $ sum1i . sum1i $ - -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> - -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ maximum1i #x) + + ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ minimum1i #x) ,("neural", adTestGen Example.neural $ do let tR = STScal STF64 -- cgit v1.2.3-70-g09d2