summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal5
-rw-r--r--src/ForwardAD.hs15
-rw-r--r--test/Main.hs28
3 files changed, 45 insertions, 3 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ad949c6..3a0de52 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -57,6 +57,9 @@ test-suite example
test-suite test
type: exitcode-stdio-1.0
main-is: test/Main.hs
- build-depends: base, chad-fast
+ build-depends:
+ chad-fast,
+ base,
+ hedgehog,
default-language: Haskell2010
ghc-options: -Wall -threaded
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 6d53b48..86d2fb0 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -67,6 +67,21 @@ zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars :: STy t -> Rep (Tan t) -> [Double]
+tanScalars STNil () = []
+tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
+tanScalars (STEither a _) (Left x) = tanScalars a x
+tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STMaybe _) Nothing = []
+tanScalars (STMaybe t) (Just x) = tanScalars t x
+tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
+tanScalars (STScal STI32) _ = []
+tanScalars (STScal STI64) _ = []
+tanScalars (STScal STF32) x = [realToFrac x]
+tanScalars (STScal STF64) x = [x]
+tanScalars (STScal STBool) _ = []
+tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =
diff --git a/test/Main.hs b/test/Main.hs
index 39415bb..045ac1c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,13 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE LambdaCase #-}
module Main where
import Data.Bifunctor
+import Hedgehog
+import Hedgehog.Main
import Array
import AST
@@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
+closeIsh :: Double -> Double -> Bool
+closeIsh a b =
+ abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
+
+adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
+adTest input expr = property $
+ let env = knownEnv @env
+ gradFwd = gradientByForward knownEnv expr input
+ gradCHAD = gradientByCHAD' knownEnv expr input
+ scFwd = envScalars env gradFwd
+ scCHAD = envScalars env gradCHAD
+ in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+ where
+ envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
+ envScalars SNil SNil = []
+ envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+
+tests :: IO Bool
+tests = checkParallel $ Group "AD"
+ [("id", adTest (Value 42.0))]
+
main :: IO ()
-main = return ()
+main = defaultMain [tests]