diff options
-rw-r--r-- | chad-fast.cabal | 5 | ||||
-rw-r--r-- | src/ForwardAD.hs | 15 | ||||
-rw-r--r-- | test/Main.hs | 28 |
3 files changed, 45 insertions, 3 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index ad949c6..3a0de52 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -57,6 +57,9 @@ test-suite example test-suite test type: exitcode-stdio-1.0 main-is: test/Main.hs - build-depends: base, chad-fast + build-depends: + chad-fast, + base, + hedgehog, default-language: Haskell2010 ghc-options: -Wall -threaded diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 6d53b48..86d2fb0 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -67,6 +67,21 @@ zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars :: STy t -> Rep (Tan t) -> [Double] +tanScalars STNil () = [] +tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y +tanScalars (STEither a _) (Left x) = tanScalars a x +tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STMaybe _) Nothing = [] +tanScalars (STMaybe t) (Just x) = tanScalars t x +tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STScal STI32) _ = [] +tanScalars (STScal STI64) _ = [] +tanScalars (STScal STF32) x = [realToFrac x] +tanScalars (STScal STF64) x = [x] +tanScalars (STScal STBool) _ = [] +tanScalars STAccum{} _ = error "Accumulators not allowed in input program" + unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) unzipDN STNil _ = ((), ()) unzipDN (STPair a b) (d1, d2) = diff --git a/test/Main.hs b/test/Main.hs index 39415bb..045ac1c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,13 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE LambdaCase #-} module Main where import Data.Bifunctor +import Hedgehog +import Hedgehog.Main import Array import AST @@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) gradientByForward env term input = drevByFwd env term input 1.0 +closeIsh :: Double -> Double -> Bool +closeIsh a b = + abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) + +adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property +adTest input expr = property $ + let env = knownEnv @env + gradFwd = gradientByForward knownEnv expr input + gradCHAD = gradientByCHAD' knownEnv expr input + scFwd = envScalars env gradFwd + scCHAD = envScalars env gradCHAD + in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd + where + envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] + envScalars SNil SNil = [] + envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs + +tests :: IO Bool +tests = checkParallel $ Group "AD" + [("id", adTest (Value 42.0))] + main :: IO () -main = return () +main = defaultMain [tests] |