diff options
| -rw-r--r-- | chad-fast.cabal | 5 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 15 | ||||
| -rw-r--r-- | test/Main.hs | 28 | 
3 files changed, 45 insertions, 3 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index ad949c6..3a0de52 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -57,6 +57,9 @@ test-suite example  test-suite test    type: exitcode-stdio-1.0    main-is: test/Main.hs -  build-depends: base, chad-fast +  build-depends: +    chad-fast, +    base, +    hedgehog,    default-language: Haskell2010    ghc-options: -Wall -threaded diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 6d53b48..86d2fb0 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -67,6 +67,21 @@ zeroTan (STScal STF64) _ = 0.0  zeroTan (STScal STBool) _ = ()  zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars :: STy t -> Rep (Tan t) -> [Double] +tanScalars STNil () = [] +tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y +tanScalars (STEither a _) (Left x) = tanScalars a x +tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STMaybe _) Nothing = [] +tanScalars (STMaybe t) (Just x) = tanScalars t x +tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STScal STI32) _ = [] +tanScalars (STScal STI64) _ = [] +tanScalars (STScal STF32) x = [realToFrac x] +tanScalars (STScal STF64) x = [x] +tanScalars (STScal STBool) _ = [] +tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +  unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))  unzipDN STNil _ = ((), ())  unzipDN (STPair a b) (d1, d2) = diff --git a/test/Main.hs b/test/Main.hs index 39415bb..045ac1c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,13 +1,16 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE LambdaCase #-}  module Main where  import Data.Bifunctor +import Hedgehog +import Hedgehog.Main  import Array  import AST @@ -95,5 +98,26 @@ gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term i  gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)  gradientByForward env term input = drevByFwd env term input 1.0 +closeIsh :: Double -> Double -> Bool +closeIsh a b = +  abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) + +adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property +adTest input expr = property $ +  let env = knownEnv @env +      gradFwd = gradientByForward knownEnv expr input +      gradCHAD = gradientByCHAD' knownEnv expr input +      scFwd = envScalars env gradFwd +      scCHAD = envScalars env gradCHAD +  in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd +  where +    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] +    envScalars SNil SNil = [] +    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs + +tests :: IO Bool +tests = checkParallel $ Group "AD" +  [("id", adTest (Value 42.0))] +  main :: IO () -main = return () +main = defaultMain [tests] | 
