diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:46:34 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:51:22 +0200 |
commit | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (patch) | |
tree | 64dcb00c9c61ad57177db5ec01c189d74dbc2d4a | |
parent | 5a802da40e5836ee19d46b9a2c771912dbff010e (diff) |
Reorganise test files
-rw-r--r-- | ox-arrays.cabal | 4 | ||||
-rw-r--r-- | test/Gen.hs | 85 | ||||
-rw-r--r-- | test/Main.hs | 151 | ||||
-rw-r--r-- | test/Tests/C.hs | 73 | ||||
-rw-r--r-- | test/Util.hs | 38 |
5 files changed, 202 insertions, 149 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index d0aed82..e53815e 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -35,6 +35,10 @@ library test-suite test type: exitcode-stdio-1.0 main-is: Main.hs + other-modules: + Gen + Tests.C + Util build-depends: ox-arrays, base, diff --git a/test/Gen.hs b/test/Gen.hs new file mode 100644 index 0000000..2d2a30b --- /dev/null +++ b/test/Gen.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Gen where + +import qualified Data.ByteString as BS +import Data.Foldable (toList) +import qualified Data.Vector.Storable as VS +import Foreign +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) +import Data.Array.Nested + +import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import qualified System.Random as Random + +import Util + + +genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () +genRank k = do + rank <- forAll $ Gen.int (Range.linear 0 8) + TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do + x <- Gen.realFloat (Range.linearFrac 0 1) + return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh + where + go :: Int -> [Int] -> IShR n -> Gen (IShR n) + go _ _ ZSR = return ZSR + go nbag bag (_ :$: sh) = do + idx <- Gen.int (Range.linear 0 (nbag - 1)) + let (dim, bag') = case splitAt idx bag of + (pre, n : post) -> (n, pre ++ post) + _ -> error "unreachable" + (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR sn = do + let n = fromSNat' sn + targetSize <- Gen.int (Range.linear 0 100_000) + let genDims :: SNat m -> Int -> Gen (IShR m) + genDims SZ _ = return ZSR + genDims (SS m) 0 = do + dim <- Gen.int (Range.linear 0 20) + dims <- genDims m 0 + return (dim :$: dims) + genDims (SS m) tgt = do + dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) + ,(2 , return tgt) + ,(4 , return 1) + ,(1 , return 0)] + dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) + return (dim :$: dims) + dims <- genDims sn targetSize + let dimsL = toList dims + maxdim = maximum dimsL + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + shuffleShR (min cap <$> dims) + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do + n <- Gen.int rng + seed <- Gen.resize 99 $ Gen.int Range.linearBounded + let gen0 = Random.mkStdGen seed + (bs, _) = Random.genByteString (8 * n) gen0 + let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) + return $ VS.generate n (f . readW64) + diff --git a/test/Main.hs b/test/Main.hs index b5237e5..7e62641 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,158 +1,11 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE NumericUnderscores #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Main where -import Control.Monad -import qualified Data.Array.RankedS as OR -import qualified Data.ByteString as BS -import Data.Foldable (toList) -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import Foreign -import GHC.TypeLits -import qualified GHC.TypeNats as TN - -import qualified Data.Array.Mixed as X -import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) -import Data.Array.Nested -import qualified Data.Array.Nested.Internal as I - --- test framework stuff -import Hedgehog -import Hedgehog.Internal.Property (forAllT) -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range -import qualified System.Random as Random import Test.Tasty -import Test.Tasty.Hedgehog - --- import Debug.Trace - - --- Returns highest value that satisfies the predicate, or `lo` if none does -binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a -binarySearch div2 = \lo hi f -> case (f lo, f hi) of - (False, _) -> lo - (_, True) -> hi - (_, _ ) -> go lo hi f - where - go lo hi f = -- invariant: f lo && not (f hi) - let mid = lo + div2 (hi - lo) - in if mid `elem` [lo, hi] - then mid - else if f mid then go mid hi f else go lo mid f - -genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () -genRank k = do - rank <- forAll $ Gen.int (Range.linear 0 8) - TN.withSomeSNat (fromIntegral rank) k -genLowBiased :: RealFloat a => (a, a) -> Gen a -genLowBiased (lo, hi) = do - x <- Gen.realFloat (Range.linearFrac 0 1) - return (lo + x * x * x * (hi - lo)) +import qualified Tests.C -shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh - where - go :: Int -> [Int] -> IShR n -> Gen (IShR n) - go _ _ ZSR = return ZSR - go nbag bag (_ :$: sh) = do - idx <- Gen.int (Range.linear 0 (nbag - 1)) - let (dim, bag') = case splitAt idx bag of - (pre, n : post) -> (n, pre ++ post) - _ -> error "unreachable" - (dim :$:) <$> go (nbag - 1) bag' sh - -genShR :: SNat n -> Gen (IShR n) -genShR sn = do - let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 100_000) - let genDims :: SNat m -> Int -> Gen (IShR m) - genDims SZ _ = return ZSR - genDims (SS m) 0 = do - dim <- Gen.int (Range.linear 0 20) - dims <- genDims m 0 - return (dim :$: dims) - genDims (SS m) tgt = do - dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) - ,(2 , return tgt) - ,(4 , return 1) - ,(1 , return 0)] - dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) - return (dim :$: dims) - dims <- genDims sn targetSize - let dimsL = toList dims - maxdim = maximum dimsL - cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) - shuffleShR (min cap <$> dims) - -genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) -genStorables rng f = do - n <- Gen.int rng - seed <- Gen.resize 99 $ Gen.int Range.linearBounded - let gen0 = Random.mkStdGen seed - (bs, _) = Random.genByteString (8 * n) gen0 - let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) - return $ VS.generate n (f . readW64) - -orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a -orSumOuter1 (sn@SNat :: SNat n) = - let n = fromSNat' sn - in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) - -rshTail :: ShR (n + 1) i -> ShR n i -rshTail (_ :$: sh) = sh -rshTail ZSR = error "unreachable" main :: IO () main = defaultMain $ testGroup "Tests" - [testGroup "C" - [testGroup "sum" - [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do - -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. - let inrank = SNat @(n + 1) - sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) - (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - -- annotateShow rarr - Refl <- return $ I.lemRankReplicate outrank - let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - let rhs = orSumOuter1 outrank arr - -- annotateShow lhs - -- annotateShow rhs - lhs === rhs - - ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do - -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. - outrank :: SNat n <- return $ SNat @(nm1 + 1) - let inrank = SNat @(n + 1) - sh <- forAll $ do - shtt <- genShR outrankm1 -- nm1 - sht <- shuffleShR (0 :$: shtt) -- n - n <- Gen.int (Range.linear 0 20) - return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (rshTail sh))) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @Double @(n + 1) (toList sh) [] - let rarr = rfromOrthotope inrank arr - Refl <- return $ I.lemRankReplicate outrank - let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - OR.toList lhs === [] - ] - ] - ] + [Tests.C.tests] diff --git a/test/Tests/C.hs b/test/Tests/C.hs new file mode 100644 index 0000000..1041b2a --- /dev/null +++ b/test/Tests/C.hs @@ -0,0 +1,73 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.C where + +import Control.Monad +import qualified Data.Array.RankedS as OR +import Data.Foldable (toList) +import Data.Type.Equality +import Foreign +import GHC.TypeLits + +import qualified Data.Array.Mixed as X +import Data.Array.Nested +import qualified Data.Array.Nested.Internal as I + +import Hedgehog +import Hedgehog.Internal.Property (forAllT) +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen +import Util + + +tests :: TestTree +tests = testGroup "C" + [testGroup "sum" + [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + -- annotateShow rarr + Refl <- return $ I.lemRankReplicate outrank + let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr + let rhs = orSumOuter1 outrank arr + -- annotateShow lhs + -- annotateShow rhs + lhs === rhs + + ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (any (== 0) (toList (rshTail sh))) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @Double @(n + 1) (toList sh) [] + let rarr = rfromOrthotope inrank arr + Refl <- return $ I.lemRankReplicate outrank + let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr + OR.toList lhs === [] + ] + ] diff --git a/test/Util.hs b/test/Util.hs new file mode 100644 index 0000000..1249bf9 --- /dev/null +++ b/test/Util.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Util where + +import qualified Data.Array.RankedS as OR +import GHC.TypeLits + +import Data.Array.Mixed (fromSNat') +import Data.Array.Nested + + +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of + (False, _) -> lo + (_, True) -> hi + (_, _ ) -> go lo hi f + where + go lo hi f = -- invariant: f lo && not (f hi) + let mid = lo + div2 (hi - lo) + in if mid `elem` [lo, hi] + then mid + else if f mid then go mid hi f else go lo mid f + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = + let n = fromSNat' sn + in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +rshTail :: ShR (n + 1) i -> ShR n i +rshTail (_ :$: sh) = sh +rshTail ZSR = error "unreachable" |