diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:35:26 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:35:26 +0100 |
commit | a17bd53598ee5266fc3a1c45f8f4bb4798dc495e (patch) | |
tree | ee7962f603fbb26a0df0f793b8e50666f41a0dfd /test | |
parent | b91d36fa38be07397b505433f24a6d29a79c2642 (diff) |
Working tests and benchmarks against 'ad'
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/test/Main.hs b/test/Main.hs index 04a8923..a04533f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,26 +1,42 @@ {-# LANGUAGE TypeApplications #-} module Main where -import qualified Data.Vector as V +import Data.Foldable (toList) import Hedgehog import Test.Tasty import Test.Tasty.Hedgehog import Test.Tasty.HUnit +import qualified Numeric.AD as AD + import Numeric.ADDual import Numeric.ADDual.Examples +(~==) :: (Foldable t, Fractional a, Ord a, Show (t a)) => t a -> t a -> PropertyT IO () +a ~== b + | length (toList a) == length (toList b) + , and (zipWith close (toList a) (toList b)) + = return () + | otherwise + = diff a (\_ _ -> False) b + where + close x y = abs (x - y) < 1e-5 || + (let m = max (abs x) (abs y) in m > 1e-5 && abs (x - y) / m < 1e-5) + + main :: IO () main = defaultMain $ testGroup "Tests" [testCase "product [1..5]" $ gradient' @Double product [1..5] 1 @?= (120, [120, 60, 40, 30, 24]) - ,testCase "neural one" $ - let problem = FNeural - [(V.replicate 6 0.0, V.replicate 6 0.0), (V.replicate 24 0.0, V.replicate 4 0.0)] - (V.replicate 1 0.0) - in fst (gradient' @Double fneural problem 1) @?= fneural problem - ,testProperty "neural run" $ property $ do - input <- forAll genNeuralInput - let (res, _grad) = gradient' fneural input 1 - res === fneural input] + ,testProperty "neural 80" $ property $ do + input <- forAll (genNeuralInput 80) + let (res, grad) = gradient' fneural input 1 + res === fneural input + grad ~== AD.grad fneural input + ,testProperty "neural 150" $ property $ do + input <- forAll (genNeuralInput 150) + let (res, grad) = gradient' fneural input 1 + res === fneural input + grad ~== AD.grad fneural input + ] |