diff options
| -rw-r--r-- | ox-arrays.cabal | 4 | ||||
| -rw-r--r-- | test/Gen.hs | 85 | ||||
| -rw-r--r-- | test/Main.hs | 151 | ||||
| -rw-r--r-- | test/Tests/C.hs | 73 | ||||
| -rw-r--r-- | test/Util.hs | 38 | 
5 files changed, 202 insertions, 149 deletions
| diff --git a/ox-arrays.cabal b/ox-arrays.cabal index d0aed82..e53815e 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -35,6 +35,10 @@ library  test-suite test    type: exitcode-stdio-1.0    main-is: Main.hs +  other-modules: +    Gen +    Tests.C +    Util    build-depends:      ox-arrays,      base, diff --git a/test/Gen.hs b/test/Gen.hs new file mode 100644 index 0000000..2d2a30b --- /dev/null +++ b/test/Gen.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Gen where + +import qualified Data.ByteString as BS +import Data.Foldable (toList) +import qualified Data.Vector.Storable as VS +import Foreign +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) +import Data.Array.Nested + +import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import qualified System.Random as Random + +import Util + + +genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () +genRank k = do +  rank <- forAll $ Gen.int (Range.linear 0 8) +  TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do +  x <- Gen.realFloat (Range.linearFrac 0 1) +  return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh +  where +    go :: Int -> [Int] -> IShR n -> Gen (IShR n) +    go _    _   ZSR = return ZSR +    go nbag bag (_ :$: sh) = do +      idx <- Gen.int (Range.linear 0 (nbag - 1)) +      let (dim, bag') = case splitAt idx bag of +                          (pre, n : post) -> (n, pre ++ post) +                          _ -> error "unreachable" +      (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR sn = do +  let n = fromSNat' sn +  targetSize <- Gen.int (Range.linear 0 100_000) +  let genDims :: SNat m -> Int -> Gen (IShR m) +      genDims SZ _ = return ZSR +      genDims (SS m) 0 = do +        dim <- Gen.int (Range.linear 0 20) +        dims <- genDims m 0 +        return (dim :$: dims) +      genDims (SS m) tgt = do +        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) +                             ,(2     , return tgt) +                             ,(4     , return 1) +                             ,(1     , return 0)] +        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) +        return (dim :$: dims) +  dims <- genDims sn targetSize +  let dimsL = toList dims +      maxdim = maximum dimsL +      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) +  shuffleShR (min cap <$> dims) + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do +  n <- Gen.int rng +  seed <- Gen.resize 99 $ Gen.int Range.linearBounded +  let gen0 = Random.mkStdGen seed +      (bs, _) = Random.genByteString (8 * n) gen0 +  let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) +  return $ VS.generate n (f . readW64) + diff --git a/test/Main.hs b/test/Main.hs index b5237e5..7e62641 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,158 +1,11 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE NumericUnderscores #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Main where -import Control.Monad -import qualified Data.Array.RankedS as OR -import qualified Data.ByteString as BS -import Data.Foldable (toList) -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import Foreign -import GHC.TypeLits -import qualified GHC.TypeNats as TN - -import qualified Data.Array.Mixed as X -import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) -import Data.Array.Nested -import qualified Data.Array.Nested.Internal as I - --- test framework stuff -import Hedgehog -import Hedgehog.Internal.Property (forAllT) -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range -import qualified System.Random as Random  import Test.Tasty -import Test.Tasty.Hedgehog - --- import Debug.Trace - - --- Returns highest value that satisfies the predicate, or `lo` if none does -binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a -binarySearch div2 = \lo hi f -> case (f lo, f hi) of -  (False, _) -> lo -  (_, True) -> hi -  (_, _ ) -> go lo hi f -  where -    go lo hi f =  -- invariant: f lo && not (f hi) -      let mid = lo + div2 (hi - lo) -      in if mid `elem` [lo, hi] -           then mid -           else if f mid then go mid hi f else go lo mid f - -genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () -genRank k = do -  rank <- forAll $ Gen.int (Range.linear 0 8) -  TN.withSomeSNat (fromIntegral rank) k -genLowBiased :: RealFloat a => (a, a) -> Gen a -genLowBiased (lo, hi) = do -  x <- Gen.realFloat (Range.linearFrac 0 1) -  return (lo + x * x * x * (hi - lo)) +import qualified Tests.C -shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh -  where -    go :: Int -> [Int] -> IShR n -> Gen (IShR n) -    go _    _   ZSR = return ZSR -    go nbag bag (_ :$: sh) = do -      idx <- Gen.int (Range.linear 0 (nbag - 1)) -      let (dim, bag') = case splitAt idx bag of -                          (pre, n : post) -> (n, pre ++ post) -                          _ -> error "unreachable" -      (dim :$:) <$> go (nbag - 1) bag' sh - -genShR :: SNat n -> Gen (IShR n) -genShR sn = do -  let n = fromSNat' sn -  targetSize <- Gen.int (Range.linear 0 100_000) -  let genDims :: SNat m -> Int -> Gen (IShR m) -      genDims SZ _ = return ZSR -      genDims (SS m) 0 = do -        dim <- Gen.int (Range.linear 0 20) -        dims <- genDims m 0 -        return (dim :$: dims) -      genDims (SS m) tgt = do -        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) -                             ,(2     , return tgt) -                             ,(4     , return 1) -                             ,(1     , return 0)] -        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) -        return (dim :$: dims) -  dims <- genDims sn targetSize -  let dimsL = toList dims -      maxdim = maximum dimsL -      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) -  shuffleShR (min cap <$> dims) - -genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) -genStorables rng f = do -  n <- Gen.int rng -  seed <- Gen.resize 99 $ Gen.int Range.linearBounded -  let gen0 = Random.mkStdGen seed -      (bs, _) = Random.genByteString (8 * n) gen0 -  let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) -  return $ VS.generate n (f . readW64) - -orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a -orSumOuter1 (sn@SNat :: SNat n) = -  let n = fromSNat' sn -  in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) - -rshTail :: ShR (n + 1) i -> ShR n i -rshTail (_ :$: sh) = sh -rshTail ZSR = error "unreachable"  main :: IO ()  main = defaultMain $    testGroup "Tests" -    [testGroup "C" -      [testGroup "sum" -        [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do -          -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. -          let inrank = SNat @(n + 1) -          sh <- forAll $ genShR inrank -          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -          guard (all (> 0) (toList (rshTail sh)))  -- only constrain the tail -          arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> -                   genStorables (Range.singleton (product sh)) -                                (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) -          let rarr = rfromOrthotope inrank arr -          -- annotateShow rarr -          Refl <- return $ I.lemRankReplicate outrank -          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -          let rhs = orSumOuter1 outrank arr -          -- annotateShow lhs -          -- annotateShow rhs -          lhs === rhs - -        ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do -          -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. -          outrank :: SNat n <- return $ SNat @(nm1 + 1) -          let inrank = SNat @(n + 1) -          sh <- forAll $ do -            shtt <- genShR outrankm1  -- nm1 -            sht <- shuffleShR (0 :$: shtt)  -- n -            n <- Gen.int (Range.linear 0 20) -            return (n :$: sht)  -- n + 1 -          guard (any (== 0) (toList (rshTail sh))) -          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -          let arr = OR.fromList @Double @(n + 1) (toList sh) [] -          let rarr = rfromOrthotope inrank arr -          Refl <- return $ I.lemRankReplicate outrank -          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -          OR.toList lhs === [] -        ] -      ] -    ] +    [Tests.C.tests] diff --git a/test/Tests/C.hs b/test/Tests/C.hs new file mode 100644 index 0000000..1041b2a --- /dev/null +++ b/test/Tests/C.hs @@ -0,0 +1,73 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.C where + +import Control.Monad +import qualified Data.Array.RankedS as OR +import Data.Foldable (toList) +import Data.Type.Equality +import Foreign +import GHC.TypeLits + +import qualified Data.Array.Mixed as X +import Data.Array.Nested +import qualified Data.Array.Nested.Internal as I + +import Hedgehog +import Hedgehog.Internal.Property (forAllT) +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen +import Util + + +tests :: TestTree +tests = testGroup "C" +  [testGroup "sum" +    [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do +      -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. +      let inrank = SNat @(n + 1) +      sh <- forAll $ genShR inrank +      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +      guard (all (> 0) (toList (rshTail sh)))  -- only constrain the tail +      arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> +               genStorables (Range.singleton (product sh)) +                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +      let rarr = rfromOrthotope inrank arr +      -- annotateShow rarr +      Refl <- return $ I.lemRankReplicate outrank +      let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr +      let rhs = orSumOuter1 outrank arr +      -- annotateShow lhs +      -- annotateShow rhs +      lhs === rhs + +    ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do +      -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. +      outrank :: SNat n <- return $ SNat @(nm1 + 1) +      let inrank = SNat @(n + 1) +      sh <- forAll $ do +        shtt <- genShR outrankm1  -- nm1 +        sht <- shuffleShR (0 :$: shtt)  -- n +        n <- Gen.int (Range.linear 0 20) +        return (n :$: sht)  -- n + 1 +      guard (any (== 0) (toList (rshTail sh))) +      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +      let arr = OR.fromList @Double @(n + 1) (toList sh) [] +      let rarr = rfromOrthotope inrank arr +      Refl <- return $ I.lemRankReplicate outrank +      let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr +      OR.toList lhs === [] +    ] +  ] diff --git a/test/Util.hs b/test/Util.hs new file mode 100644 index 0000000..1249bf9 --- /dev/null +++ b/test/Util.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Util where + +import qualified Data.Array.RankedS as OR +import GHC.TypeLits + +import Data.Array.Mixed (fromSNat') +import Data.Array.Nested + + +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of +  (False, _) -> lo +  (_, True) -> hi +  (_, _ ) -> go lo hi f +  where +    go lo hi f =  -- invariant: f lo && not (f hi) +      let mid = lo + div2 (hi - lo) +      in if mid `elem` [lo, hi] +           then mid +           else if f mid then go mid hi f else go lo mid f + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = +  let n = fromSNat' sn +  in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +rshTail :: ShR (n + 1) i -> ShR n i +rshTail (_ :$: sh) = sh +rshTail ZSR = error "unreachable" | 
