diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-24 09:38:12 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-24 09:38:12 +0200 | 
| commit | 7afc48ac656162972f780740c675d87d15b2fe1d (patch) | |
| tree | 1c369812da61680e88ae527e6c259d9327ae71a2 | |
| parent | 97ea9bbf91c2e42718b734b4c025eb101ea8218d (diff) | |
Test C sum with random inputs
| -rw-r--r-- | ox-arrays.cabal | 5 | ||||
| -rw-r--r-- | test/Main.hs | 75 | 
2 files changed, 68 insertions, 12 deletions
| diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 94a4529..1eea23d 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -34,11 +34,14 @@ test-suite test    build-depends:      ox-arrays,      base, +    bytestring,      ghc-typelits-knownnat,      hedgehog,      orthotope, +    random,      tasty, -    tasty-hedgehog +    tasty-hedgehog, +    vector    hs-source-dirs: test    default-language: Haskell2010    ghc-options: -Wall diff --git a/test/Main.hs b/test/Main.hs index 002c606..dd59586 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,17 +1,22 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE NumericUnderscores #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DataKinds #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Main where +import Control.Monad  import qualified Data.Array.RankedS as OR +import qualified Data.ByteString as BS  import Data.Foldable (toList)  import Data.Type.Equality +import qualified Data.Vector.Storable as VS +import Foreign  import GHC.TypeLits  import qualified GHC.TypeNats as TN @@ -22,14 +27,29 @@ import qualified Data.Array.Nested.Internal as I  -- test framework stuff  import Hedgehog +import Hedgehog.Internal.Property (forAllT)  import qualified Hedgehog.Gen as Gen  import qualified Hedgehog.Range as Range +import qualified System.Random as Random  import Test.Tasty  import Test.Tasty.Hedgehog  import Debug.Trace +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of +  (False, _) -> lo +  (_, True) -> hi +  (_, _ ) -> go lo hi f +  where +    go lo hi f =  -- invariant: f lo && not (f hi) +      let mid = lo + div2 (hi - lo) +      in if mid `elem` [lo, hi] +           then mid +           else if f mid then go mid hi f else go lo mid f +  genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()  genRank k = do    rank <- forAll $ Gen.int (Range.linear 0 8) @@ -55,7 +75,7 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh  genShR :: SNat n -> Gen (IShR n)  genShR sn = do    let n = fromSNat' sn -  targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n)) +  targetSize <- Gen.int (Range.linear 0 100_000)    let genDims :: SNat m -> Int -> Gen (IShR m)        genDims SZ _ = return ZSR        genDims (SS m) 0 = do @@ -69,7 +89,20 @@ genShR sn = do                               ,(1     , return 0)]          dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)          return (dim :$: dims) -  shuffleShR =<< genDims sn targetSize +  dims <- genDims sn targetSize +  let dimsL = toList dims +      maxdim = maximum dimsL +      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) +  shuffleShR (min cap <$> dims) + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do +  n <- Gen.int rng +  seed <- Gen.resize 99 $ Gen.int Range.linearBounded +  let gen0 = Random.mkStdGen seed +      (bs, _) = Random.genByteString (8 * n) gen0 +  let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) +  return $ VS.generate n (f . readW64)  orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a  orSumOuter1 (sn@SNat :: SNat n) = @@ -81,20 +114,40 @@ main = defaultMain $    testGroup "Tests"      [testGroup "C"        [testGroup "sum" -        [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do +        [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do +          -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.            let inrank = SNat @(n + 1)            sh <- forAll $ genShR inrank -          arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$> -                   Gen.list (Range.singleton (product sh)) -                            (Gen.realFloat (Range.linearFrac @Double 0 1)) +          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +          guard (all (> 0) (tail (toList sh)))  -- only constrain the tail +          arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> +                   genStorables (Range.singleton (product sh)) +                                (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))            let rarr = rfromOrthotope inrank arr -          annotateShow rarr +          -- annotateShow rarr            Refl <- return $ I.lemRankReplicate outrank            let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr            let rhs = orSumOuter1 outrank arr -          annotateShow lhs -          annotateShow rhs +          -- annotateShow lhs +          -- annotateShow rhs            lhs === rhs + +        ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do +          -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. +          outrank :: SNat n <- return $ SNat @(nm1 + 1) +          let inrank = SNat @(n + 1) +          sh <- forAll $ do +            shtt <- genShR outrankm1  -- nm1 +            sht <- shuffleShR (0 :$: shtt)  -- n +            n <- Gen.int (Range.linear 0 20) +            return (n :$: sht)  -- n + 1 +          guard (any (== 0) (tail (toList sh))) +          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +          let arr = OR.fromList @Double @(n + 1) (toList sh) [] +          let rarr = rfromOrthotope inrank arr +          Refl <- return $ I.lemRankReplicate outrank +          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr +          OR.toList lhs === []          ]        ]      ] | 
