summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs28
1 files changed, 26 insertions, 2 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 39415bb..045ac1c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,13 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE LambdaCase #-}
module Main where
import Data.Bifunctor
+import Hedgehog
+import Hedgehog.Main
import Array
import AST
@@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
+closeIsh :: Double -> Double -> Bool
+closeIsh a b =
+ abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
+
+adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
+adTest input expr = property $
+ let env = knownEnv @env
+ gradFwd = gradientByForward knownEnv expr input
+ gradCHAD = gradientByCHAD' knownEnv expr input
+ scFwd = envScalars env gradFwd
+ scCHAD = envScalars env gradCHAD
+ in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+ where
+ envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
+ envScalars SNil SNil = []
+ envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+
+tests :: IO Bool
+tests = checkParallel $ Group "AD"
+ [("id", adTest (Value 42.0))]
+
main :: IO ()
-main = return ()
+main = defaultMain [tests]