blob: a04533f9269f54ea7f6a336554ecfa4c78ec5e43 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
{-# LANGUAGE TypeApplications #-}
module Main where
import Data.Foldable (toList)
import GHC.Stack (HasCallStack, withFrozenCallStack)

import Hedgehog
import qualified Numeric.AD as AD
import Test.Tasty
import Test.Tasty.Hedgehog
import Test.Tasty.HUnit

import Numeric.ADDual
import Numeric.ADDual.Examples
-- | Assert that two containers have the same number of elements and that
-- corresponding elements are approximately equal.
--
-- Elements @x@ and @y@ count as close when @|x - y| < 1e-5@, or when their
-- difference relative to @max |x| |y|@ is below @1e-5@. On failure both
-- containers are rendered with Hedgehog's 'diff', so the mismatch is visible
-- in the test report.
--
-- The 'HasCallStack' constraint (together with 'withFrozenCallStack' below)
-- makes Hedgehog attribute a failure to the caller's line rather than to
-- this helper, mirroring how Hedgehog's own '===' behaves.
(~==)
  :: (MonadTest m, Foldable t, Fractional a, Ord a, Show (t a), HasCallStack)
  => t a -> t a -> m ()
a ~== b
  | length xs == length ys
  , and (zipWith close xs ys)
  = pure ()
  | otherwise
  -- The comparison has already failed; 'diff' with an always-False
  -- predicate is used purely to fail with a rendered side-by-side diff.
  = withFrozenCallStack $ diff a (\_ _ -> False) b
  where
    xs = toList a
    ys = toList b
    -- Shared absolute and relative tolerance.
    tol = 1e-5
    close x y =
      abs (x - y) < tol
        || (let m = max (abs x) (abs y) in m > tol && abs (x - y) / m < tol)

-- Same precedence as Hedgehog's '===', so e.g. @a + b ~== c@ groups as
-- @(a + b) ~== c@ instead of the default @infixl 9@.
infix 4 ~==
-- | Test entry point.
--
-- * A unit test checking that 'gradient'' on @product [1..5]@ yields the
--   value 120 and the partial derivatives @[120, 60, 40, 30, 24]@.
-- * Property tests on 'fneural' for several input sizes. For each size, the
--   value returned by 'gradient'' must equal a plain evaluation of 'fneural'.
--   Its gradient must approximately match 'AD.grad' from "Numeric.AD".
main :: IO ()
main = defaultMain $ testGroup "Tests" $
  [ testCase "product [1..5]" $
      gradient' @Double product [1..5] 1 @?= (120, [120, 60, 40, 30, 24])
  ] ++ map neuralProperty [80, 150]
  where
    -- One property per input size, named "neural <size>"; adding a size
    -- only requires extending the list above.
    neuralProperty size = testProperty ("neural " ++ show size) $ property $ do
      input <- forAll (genNeuralInput size)
      let (res, grad) = gradient' fneural input 1
      res === fneural input
      grad ~== AD.grad fneural input
|