diff options
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 28 | 
1 files changed, 26 insertions, 2 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 39415bb..045ac1c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,13 +1,16 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE LambdaCase #-}  module Main where  import Data.Bifunctor +import Hedgehog +import Hedgehog.Main  import Array  import AST @@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i  gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)  gradientByForward env term input = drevByFwd env term input 1.0 +closeIsh :: Double -> Double -> Bool +closeIsh a b = +  abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) + +adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property +adTest input expr = property $ +  let env = knownEnv @env +      gradFwd = gradientByForward knownEnv expr input +      gradCHAD = gradientByCHAD' knownEnv expr input +      scFwd = envScalars env gradFwd +      scCHAD = envScalars env gradCHAD +  in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd +  where +    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] +    envScalars SNil SNil = [] +    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs + +tests :: IO Bool +tests = checkParallel $ Group "AD" +  [("id", adTest (Value 42.0))] +  main :: IO () -main = return () +main = defaultMain [tests] | 
