summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-15 13:44:43 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-15 13:44:43 +0200
commit6fb15f0b632d3651cc3e6089c20b07b009b578eb (patch)
tree8513de490e152dc0cdb324fedf2324e6f04d1071 /test
parent5a282fa0256d75dd310014fac20949ef56946053 (diff)
We can differentiate id
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs83
1 files changed, 78 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 045ac1c..34ab5af 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -9,7 +10,10 @@
module Main where
import Data.Bifunctor
+import Data.List (intercalate)
import Hedgehog
+import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Range as Range
import Hedgehog.Main
import Array
@@ -20,6 +24,7 @@ import Data
import ForwardAD
import Interpreter
import Interpreter.Rep
+import Language
type family MapMerge env where
@@ -102,14 +107,82 @@ closeIsh :: Double -> Double -> Bool
closeIsh a b =
abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
-adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
-adTest input expr = property $
+genShape :: SNat n -> Gen (Shape n)
+genShape = \n -> do
+ sh <- genShapeNaive n
+ let sz = shapeSize sh
+ factor = sz `div` 100 + 1
+ return (shapeDiv sh factor)
+ where
+ genShapeNaive :: SNat n -> Gen (Shape n)
+ genShapeNaive SZ = return ShNil
+ genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10)
+
+ shapeDiv :: Shape n -> Int -> Shape n
+ shapeDiv ShNil _ = ShNil
+ shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
+
+genValue :: STy a -> Gen (Value a)
+genValue = \case
+ STNil -> return (Value ())
+ STPair a b -> lv2 (,) <$> genValue a <*> genValue b
+ STEither a b -> Gen.choice [lv1 Left <$> genValue a
+ ,lv1 Right <$> genValue b]
+ STMaybe t -> Gen.choice [return (Value Nothing)
+ ,lv1 Just <$> genValue t]
+ STArr n t -> do
+ sh <- genShape n
+ Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t)
+ STScal sty -> case sty of
+ STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
+ STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
+ STBool -> Gen.choice [return (Value False), return (Value True)]
+ STAccum{} -> error "Cannot generate inputs for accumulators"
+ where
+ lv1 :: (Rep a -> Rep b) -> Value a -> Value b
+ lv1 f (Value x) = Value (f x)
+
+ lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
+ lv2 f (Value x) (Value y) = Value (f x y)
+
+genEnv :: SList STy env -> Gen (SList Value env)
+genEnv SNil = return SNil
+genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
+
+showValue :: Int -> STy t -> Rep t -> ShowS
+showValue _ STNil () = showString "()"
+showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
+showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
+showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
+showValue _ (STMaybe _) Nothing = showString "Nothing"
+showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
+showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve
+showValue _ (STScal sty) x = case sty of
+ STF32 -> shows x
+ STF64 -> shows x
+ STI32 -> shows x
+ STI64 -> shows x
+ STBool -> shows x
+showValue _ STAccum{} _ = error "Cannot show accumulators"
+
+showEnv :: SList STy env -> SList Value env -> String
+showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
+ where
+ showEntries :: SList STy env -> SList Value env -> [String]
+ showEntries SNil SNil = []
+ showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+
+adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
+adTest expr = property $ do
let env = knownEnv @env
- gradFwd = gradientByForward knownEnv expr input
+ input <- forAllWith (showEnv env) $ genEnv env
+ let gradFwd = gradientByForward knownEnv expr input
gradCHAD = gradientByCHAD' knownEnv expr input
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
- in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
@@ -117,7 +190,7 @@ adTest input expr = property $
tests :: IO Bool
tests = checkParallel $ Group "AD"
- [("id", adTest (Value 42.0))]
+ [("id", adTest $ fromNamed $ lambda #x $ body $ #x)]
main :: IO ()
main = defaultMain [tests]