From 5a282fa0256d75dd310014fac20949ef56946053 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 14 Oct 2024 12:20:49 +0200
Subject: More towards test suite

---
 test/Main.hs | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

(limited to 'test')

diff --git a/test/Main.hs b/test/Main.hs
index 39415bb..045ac1c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,13 +1,16 @@
 {-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE PolyKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE LambdaCase #-}
 module Main where
 
 import Data.Bifunctor
+import Hedgehog
+import Hedgehog.Main
 
 import Array
 import AST
@@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i
 gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
 gradientByForward env term input = drevByFwd env term input 1.0
 
+closeIsh :: Double -> Double -> Bool
+closeIsh a b =
+  abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
+
+adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
+adTest input expr = property $
+  let env = knownEnv @env
+      gradFwd = gradientByForward knownEnv expr input
+      gradCHAD = gradientByCHAD' knownEnv expr input
+      scFwd = envScalars env gradFwd
+      scCHAD = envScalars env gradCHAD
+  in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+  where
+    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
+    envScalars SNil SNil = []
+    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+
+tests :: IO Bool
+tests = checkParallel $ Group "AD"
+  [("id", adTest (Value 42.0))]
+
 main :: IO ()
-main = return ()
+main = defaultMain [tests]
-- 
cgit v1.2.3-70-g09d2