aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: a04533f9269f54ea7f6a336554ecfa4c78ec5e43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
{-# LANGUAGE TypeApplications #-}
module Main where

import Data.Foldable (toList)
import GHC.Stack (HasCallStack, withFrozenCallStack)

import Hedgehog
import qualified Numeric.AD as AD
import Test.Tasty
import Test.Tasty.Hedgehog
import Test.Tasty.HUnit

import Numeric.ADDual
import Numeric.ADDual.Examples


-- | Hedgehog assertion: both containers hold the same number of elements and
-- every pair of corresponding elements is approximately equal.
--
-- Two numbers @x@ and @y@ count as close when their absolute difference is
-- below the tolerance (1e-5). They also count as close when the larger
-- magnitude exceeds the tolerance and their relative difference is below it.
--
-- On failure a diff of both containers is reported. Because of
-- 'HasCallStack' / 'withFrozenCallStack', the failure points at the call site
-- of '~==', not at this helper.
(~==) :: (HasCallStack, MonadTest m, Foldable t, Fractional a, Ord a, Show (t a))
      => t a -> t a -> m ()
a ~== b = withFrozenCallStack $ diff a approxEq b
  where
    approxEq xs ys =
      let xs' = toList xs
          ys' = toList ys
      -- zipWith silently truncates to the shorter list, so the lengths must
      -- be compared explicitly for the element-wise check to mean anything.
      in length xs' == length ys' && and (zipWith close xs' ys')

    tol = 1e-5

    close x y =
      let d = abs (x - y)
          m = max (abs x) (abs y)
      -- The absolute check handles values near zero; the relative check
      -- handles large values. The @m > tol@ guard keeps the division
      -- well-defined.
      in d < tol || (m > tol && d / m < tol)


-- | Test-suite entry point.
--
-- * A unit test checks 'gradient'' on @product@ of a small fixed list.
-- * Property tests check 'gradient'' on 'fneural' for two input sizes.
main :: IO ()
main = defaultMain $ testGroup "Tests"
  [ testCase "product [1..5]" $
      gradient' @Double product [1..5] 1 @?= (120, [120, 60, 40, 30, 24])
  , neuralTest "neural 80" 80
  , neuralTest "neural 150" 150
  ]
  where
    -- Shared body of the neural-network property tests. For a random input
    -- from 'genNeuralInput' of the given size, it checks two things:
    --   * the primal result from 'gradient'' equals evaluating 'fneural'
    --     directly;
    --   * the gradient approximately matches the one from the @ad@ package.
    neuralTest name size = testProperty name $ property $ do
      input <- forAll (genNeuralInput size)
      let (res, grad) = gradient' fneural input 1
      res === fneural input
      grad ~== AD.grad fneural input