diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-15 13:44:43 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-15 13:44:43 +0200 |
commit | 6fb15f0b632d3651cc3e6089c20b07b009b578eb (patch) | |
tree | 8513de490e152dc0cdb324fedf2324e6f04d1071 /test | |
parent | 5a282fa0256d75dd310014fac20949ef56946053 (diff) |
We can differentiate id
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 83 |
1 files changed, 78 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs index 045ac1c..34ab5af 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -9,7 +10,10 @@ module Main where import Data.Bifunctor +import Data.List (intercalate) import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range import Hedgehog.Main import Array @@ -20,6 +24,7 @@ import Data import ForwardAD import Interpreter import Interpreter.Rep +import Language type family MapMerge env where @@ -102,14 +107,82 @@ closeIsh :: Double -> Double -> Bool closeIsh a b = abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) -adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property -adTest input expr = property $ +genShape :: SNat n -> Gen (Shape n) +genShape = \n -> do + sh <- genShapeNaive n + let sz = shapeSize sh + factor = sz `div` 100 + 1 + return (shapeDiv sh factor) + where + genShapeNaive :: SNat n -> Gen (Shape n) + genShapeNaive SZ = return ShNil + genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10) + + shapeDiv :: Shape n -> Int -> Shape n + shapeDiv ShNil _ = ShNil + shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) + +genValue :: STy a -> Gen (Value a) +genValue = \case + STNil -> return (Value ()) + STPair a b -> lv2 (,) <$> genValue a <*> genValue b + STEither a b -> Gen.choice [lv1 Left <$> genValue a + ,lv1 Right <$> genValue b] + STMaybe t -> Gen.choice [return (Value Nothing) + ,lv1 Just <$> genValue t] + STArr n t -> do + sh <- genShape n + Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t) + STScal sty -> case sty of + STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) + STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) + STBool -> Gen.choice [return (Value False), return (Value True)] + STAccum{} -> error "Cannot generate inputs for accumulators" + where + lv1 :: (Rep a -> Rep b) -> Value a -> Value b + lv1 f (Value x) = Value (f x) + + lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c + lv2 f (Value x) (Value y) = Value (f x y) + +genEnv :: SList STy env -> Gen (SList Value env) +genEnv SNil = return SNil +genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env + +showValue :: Int -> STy t -> Rep t -> ShowS +showValue _ STNil () = showString "()" +showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y +showValue _ (STMaybe _) Nothing = showString "Nothing" +showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x +showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve +showValue _ (STScal sty) x = case sty of + STF32 -> shows x + STF64 -> shows x + STI32 -> shows x + STI64 -> shows x + STBool -> shows x +showValue _ STAccum{} _ = error "Cannot show accumulators" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property +adTest expr = property $ do let env = knownEnv @env - gradFwd = gradientByForward knownEnv expr input + input <- forAllWith (showEnv env) $ genEnv env + let gradFwd = gradientByForward knownEnv expr input gradCHAD = gradientByCHAD' knownEnv expr input scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD - in diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd + diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] @@ -117,7 +190,7 @@ adTest input expr = property $ tests :: IO Bool tests = checkParallel $ Group "AD" - [("id", adTest (Value 42.0))] + [("id", adTest $ fromNamed $ lambda #x $ body $ #x)] main :: IO () main = defaultMain [tests] |