diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-14 12:20:49 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-14 12:20:49 +0200 |
commit | 5a282fa0256d75dd310014fac20949ef56946053 (patch) | |
tree | 84c923ceddc225761ad702f42962d949e92e7a69 /test | |
parent | 72eddb67bb6f048fc2076184be3a32169026a832 (diff) |
More towards test suite
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 28 |
1 files changed, 26 insertions, 2 deletions
diff --git a/test/Main.hs b/test/Main.hs index 39415bb..045ac1c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,13 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE LambdaCase #-} module Main where import Data.Bifunctor +import Hedgehog +import Hedgehog.Main import Array import AST @@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) gradientByForward env term input = drevByFwd env term input 1.0 +closeIsh :: Double -> Double -> Bool +closeIsh a b = + abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) + +adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property +adTest input expr = property $ + let env = knownEnv @env + gradFwd = gradientByForward knownEnv expr input + gradCHAD = gradientByCHAD' knownEnv expr input + scFwd = envScalars env gradFwd + scCHAD = envScalars env gradCHAD + in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd + where + envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] + envScalars SNil SNil = [] + envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs + +tests :: IO Bool +tests = checkParallel $ Group "AD" + [("id", adTest (Value 42.0))] + main :: IO () -main = return () +main = defaultMain [tests] |