aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 04a89233cd21d5d2d29ad93b0295f67de43d5516 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
{-# LANGUAGE TypeApplications #-}
module Main where

import qualified Data.Vector as V
import Hedgehog
import Test.Tasty
import Test.Tasty.Hedgehog
import Test.Tasty.HUnit

import Numeric.ADDual
import Numeric.ADDual.Examples


-- | Test-suite entry point: runs the @"Tests"@ group through tasty.
--
-- The group holds two HUnit cases and one Hedgehog property, all of which
-- go through 'gradient'':
--
-- * the primal value and the gradient of 'product' at @[1..5]@;
-- * the primal value of 'fneural' on a fixed problem built from zeros;
-- * the primal value of 'fneural' on generated inputs.
main :: IO ()
main = defaultMain (testGroup "Tests" tests)
  where
    tests :: [TestTree]
    tests = [productGradient, neuralZeros, neuralPrimal]

    -- The partial derivative of x1 * ... * x5 with respect to xi is the
    -- product of the other four factors, i.e. 120 / xi at the point [1..5].
    productGradient :: TestTree
    productGradient =
      testCase "product [1..5]" $
        gradient' @Double product [1 .. 5] 1
          @?= (120, [120, 60, 40, 30, 24])

    -- Only the first component of the result (the primal value) is checked;
    -- it must agree with evaluating 'fneural' directly on the same problem.
    neuralZeros :: TestTree
    neuralZeros = testCase "neural one" (primal @?= fneural zeroProblem)
      where
        primal = fst (gradient' @Double fneural zeroProblem 1)
        zeroProblem =
          FNeural
            [ (V.replicate 6 0.0, V.replicate 6 0.0)
            , (V.replicate 24 0.0, V.replicate 4 0.0)
            ]
            (V.replicate 1 0.0)

    -- Same primal-value check as above, over generated inputs; the gradient
    -- component of the result is discarded, not inspected.
    neuralPrimal :: TestTree
    neuralPrimal =
      testProperty "neural run" . property $ do
        input <- forAll genNeuralInput
        fst (gradient' fneural input 1) === fneural input