{-# LANGUAGE TypeApplications #-}

-- | Test suite for "Numeric.ADDual": checks 'gradient'' on a hand-computed
-- example and against the @ad@ package's 'AD.grad' on generated inputs.
module Main where

import Data.Foldable (toList)
import GHC.Stack (HasCallStack, withFrozenCallStack)

import Hedgehog
import qualified Numeric.AD as AD
import Test.Tasty
import Test.Tasty.Hedgehog
import Test.Tasty.HUnit

import Numeric.ADDual
import Numeric.ADDual.Examples

-- Same precedence as Hedgehog's '===' so the two read alike in a property.
infix 4 ~==

-- | Approximate-equality assertion for containers of fractional numbers.
--
-- Succeeds when both containers have the same number of elements and every
-- pair of corresponding elements is within @1e-5@ of each other, either
-- absolutely or relatively (relative to the larger magnitude, when that
-- magnitude exceeds @1e-5@). Otherwise fails, showing a Hedgehog diff of
-- the two containers.
(~==) ::
  (HasCallStack, Foldable t, Fractional a, Ord a, Show (t a)) =>
  t a -> t a -> PropertyT IO ()
a ~== b
  | length xs == length ys
  , and (zipWith close xs ys) =
      success
  | otherwise =
      -- Freeze the call stack so the failure is attributed to the caller's
      -- line (as Hedgehog's own '===' does), not to this helper.
      withFrozenCallStack $ diff a (\_ _ -> False) b
  where
    -- Convert each argument once instead of once per use.
    xs = toList a
    ys = toList b

    -- Absolute tolerance first; the relative test is only used when the
    -- larger magnitude is itself above the tolerance (avoids dividing by
    -- values near zero).
    close x y = delta < eps || (m > eps && delta / m < eps)
      where
        eps = 1e-5
        delta = abs (x - y)
        m = max (abs x) (abs y)

-- | Runs the test tree.
main :: IO ()
main =
  defaultMain $
    testGroup
      "Tests"
      [ -- The trailing @1@ is presumably the output adjoint/seed passed to
        -- 'gradient'' — confirm against "Numeric.ADDual".
        testCase "product [1..5]" $
          gradient' @Double product [1 .. 5] 1 @?= (120, [120, 60, 40, 30, 24])
      , neuralProperty "neural 80" 80
      , neuralProperty "neural 150" 150
      ]
  where
    -- Property comparing 'gradient'' on 'fneural' against the @ad@ package:
    -- the primal result must match exactly, the gradient approximately.
    -- The size argument is forwarded to 'genNeuralInput' unchanged.
    neuralProperty name size =
      testProperty name $ property $ do
        input <- forAll (genNeuralInput size)
        let (res, grad) = gradient' fneural input 1
        res === fneural input
        grad ~== AD.grad fneural input