aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-28 21:46:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-28 21:51:22 +0200
commitd8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (patch)
tree64dcb00c9c61ad57177db5ec01c189d74dbc2d4a /test
parent5a802da40e5836ee19d46b9a2c771912dbff010e (diff)
Reorganise test files
Diffstat (limited to 'test')
-rw-r--r--test/Gen.hs85
-rw-r--r--test/Main.hs151
-rw-r--r--test/Tests/C.hs73
-rw-r--r--test/Util.hs38
4 files changed, 198 insertions, 149 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
new file mode 100644
index 0000000..2d2a30b
--- /dev/null
+++ b/test/Gen.hs
@@ -0,0 +1,85 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Gen where
+
+import qualified Data.ByteString as BS
+import Data.Foldable (toList)
+import qualified Data.Vector.Storable as VS
+import Foreign
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
+import Data.Array.Nested
+
+import Hedgehog
+import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Range as Range
+import qualified System.Random as Random
+
+import Util
+
+
+genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
+genRank k = do
+ rank <- forAll $ Gen.int (Range.linear 0 8)
+ TN.withSomeSNat (fromIntegral rank) k
+
+genLowBiased :: RealFloat a => (a, a) -> Gen a
+genLowBiased (lo, hi) = do
+ x <- Gen.realFloat (Range.linearFrac 0 1)
+ return (lo + x * x * x * (hi - lo))
+
+shuffleShR :: IShR n -> Gen (IShR n)
+shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
+ where
+ go :: Int -> [Int] -> IShR n -> Gen (IShR n)
+ go _ _ ZSR = return ZSR
+ go nbag bag (_ :$: sh) = do
+ idx <- Gen.int (Range.linear 0 (nbag - 1))
+ let (dim, bag') = case splitAt idx bag of
+ (pre, n : post) -> (n, pre ++ post)
+ _ -> error "unreachable"
+ (dim :$:) <$> go (nbag - 1) bag' sh
+
+genShR :: SNat n -> Gen (IShR n)
+genShR sn = do
+ let n = fromSNat' sn
+ targetSize <- Gen.int (Range.linear 0 100_000)
+ let genDims :: SNat m -> Int -> Gen (IShR m)
+ genDims SZ _ = return ZSR
+ genDims (SS m) 0 = do
+ dim <- Gen.int (Range.linear 0 20)
+ dims <- genDims m 0
+ return (dim :$: dims)
+ genDims (SS m) tgt = do
+ dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
+ ,(2 , return tgt)
+ ,(4 , return 1)
+ ,(1 , return 0)]
+ dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
+ return (dim :$: dims)
+ dims <- genDims sn targetSize
+ let dimsL = toList dims
+ maxdim = maximum dimsL
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ shuffleShR (min cap <$> dims)
+
+genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
+genStorables rng f = do
+ n <- Gen.int rng
+ seed <- Gen.resize 99 $ Gen.int Range.linearBounded
+ let gen0 = Random.mkStdGen seed
+ (bs, _) = Random.genByteString (8 * n) gen0
+ let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
+ return $ VS.generate n (f . readW64)
+
diff --git a/test/Main.hs b/test/Main.hs
index b5237e5..7e62641 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,158 +1,11 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE NumericUnderscores #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where
-import Control.Monad
-import qualified Data.Array.RankedS as OR
-import qualified Data.ByteString as BS
-import Data.Foldable (toList)
-import Data.Type.Equality
-import qualified Data.Vector.Storable as VS
-import Foreign
-import GHC.TypeLits
-import qualified GHC.TypeNats as TN
-
-import qualified Data.Array.Mixed as X
-import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
-import Data.Array.Nested
-import qualified Data.Array.Nested.Internal as I
-
--- test framework stuff
-import Hedgehog
-import Hedgehog.Internal.Property (forAllT)
-import qualified Hedgehog.Gen as Gen
-import qualified Hedgehog.Range as Range
-import qualified System.Random as Random
import Test.Tasty
-import Test.Tasty.Hedgehog
-
--- import Debug.Trace
-
-
--- Returns highest value that satisfies the predicate, or `lo` if none does
-binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
-binarySearch div2 = \lo hi f -> case (f lo, f hi) of
- (False, _) -> lo
- (_, True) -> hi
- (_, _ ) -> go lo hi f
- where
- go lo hi f = -- invariant: f lo && not (f hi)
- let mid = lo + div2 (hi - lo)
- in if mid `elem` [lo, hi]
- then mid
- else if f mid then go mid hi f else go lo mid f
-
-genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
-genRank k = do
- rank <- forAll $ Gen.int (Range.linear 0 8)
- TN.withSomeSNat (fromIntegral rank) k
-genLowBiased :: RealFloat a => (a, a) -> Gen a
-genLowBiased (lo, hi) = do
- x <- Gen.realFloat (Range.linearFrac 0 1)
- return (lo + x * x * x * (hi - lo))
+import qualified Tests.C
-shuffleShR :: IShR n -> Gen (IShR n)
-shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
- where
- go :: Int -> [Int] -> IShR n -> Gen (IShR n)
- go _ _ ZSR = return ZSR
- go nbag bag (_ :$: sh) = do
- idx <- Gen.int (Range.linear 0 (nbag - 1))
- let (dim, bag') = case splitAt idx bag of
- (pre, n : post) -> (n, pre ++ post)
- _ -> error "unreachable"
- (dim :$:) <$> go (nbag - 1) bag' sh
-
-genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
- let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
- let genDims :: SNat m -> Int -> Gen (IShR m)
- genDims SZ _ = return ZSR
- genDims (SS m) 0 = do
- dim <- Gen.int (Range.linear 0 20)
- dims <- genDims m 0
- return (dim :$: dims)
- genDims (SS m) tgt = do
- dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
- ,(2 , return tgt)
- ,(4 , return 1)
- ,(1 , return 0)]
- dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
- return (dim :$: dims)
- dims <- genDims sn targetSize
- let dimsL = toList dims
- maxdim = maximum dimsL
- cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
- shuffleShR (min cap <$> dims)
-
-genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
-genStorables rng f = do
- n <- Gen.int rng
- seed <- Gen.resize 99 $ Gen.int Range.linearBounded
- let gen0 = Random.mkStdGen seed
- (bs, _) = Random.genByteString (8 * n) gen0
- let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
- return $ VS.generate n (f . readW64)
-
-orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
-orSumOuter1 (sn@SNat :: SNat n) =
- let n = fromSNat' sn
- in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
-
-rshTail :: ShR (n + 1) i -> ShR n i
-rshTail (_ :$: sh) = sh
-rshTail ZSR = error "unreachable"
main :: IO ()
main = defaultMain $
testGroup "Tests"
- [testGroup "C"
- [testGroup "sum"
- [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
- -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
- let inrank = SNat @(n + 1)
- sh <- forAll $ genShR inrank
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
- genStorables (Range.singleton (product sh))
- (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- -- annotateShow rarr
- Refl <- return $ I.lemRankReplicate outrank
- let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- let rhs = orSumOuter1 outrank arr
- -- annotateShow lhs
- -- annotateShow rhs
- lhs === rhs
-
- ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
- -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
- outrank :: SNat n <- return $ SNat @(nm1 + 1)
- let inrank = SNat @(n + 1)
- sh <- forAll $ do
- shtt <- genShR outrankm1 -- nm1
- sht <- shuffleShR (0 :$: shtt) -- n
- n <- Gen.int (Range.linear 0 20)
- return (n :$: sht) -- n + 1
- guard (any (== 0) (toList (rshTail sh)))
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- let arr = OR.fromList @Double @(n + 1) (toList sh) []
- let rarr = rfromOrthotope inrank arr
- Refl <- return $ I.lemRankReplicate outrank
- let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- OR.toList lhs === []
- ]
- ]
- ]
+ [Tests.C.tests]
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
new file mode 100644
index 0000000..1041b2a
--- /dev/null
+++ b/test/Tests/C.hs
@@ -0,0 +1,73 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.C where
+
+import Control.Monad
+import qualified Data.Array.RankedS as OR
+import Data.Foldable (toList)
+import Data.Type.Equality
+import Foreign
+import GHC.TypeLits
+
+import qualified Data.Array.Mixed as X
+import Data.Array.Nested
+import qualified Data.Array.Nested.Internal as I
+
+import Hedgehog
+import Hedgehog.Internal.Property (forAllT)
+import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Range as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+import Util
+
+
+tests :: TestTree
+tests = testGroup "C"
+ [testGroup "sum"
+ [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ -- annotateShow rarr
+ Refl <- return $ I.lemRankReplicate outrank
+ let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ let rhs = orSumOuter1 outrank arr
+ -- annotateShow lhs
+ -- annotateShow rhs
+ lhs === rhs
+
+ ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (any (== 0) (toList (rshTail sh)))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @Double @(n + 1) (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ Refl <- return $ I.lemRankReplicate outrank
+ let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ OR.toList lhs === []
+ ]
+ ]
diff --git a/test/Util.hs b/test/Util.hs
new file mode 100644
index 0000000..1249bf9
--- /dev/null
+++ b/test/Util.hs
@@ -0,0 +1,38 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Util where
+
+import qualified Data.Array.RankedS as OR
+import GHC.TypeLits
+
+import Data.Array.Mixed (fromSNat')
+import Data.Array.Nested
+
+
+-- Returns highest value that satisfies the predicate, or `lo` if none does
+binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
+binarySearch div2 = \lo hi f -> case (f lo, f hi) of
+ (False, _) -> lo
+ (_, True) -> hi
+ (_, _ ) -> go lo hi f
+ where
+ go lo hi f = -- invariant: f lo && not (f hi)
+ let mid = lo + div2 (hi - lo)
+ in if mid `elem` [lo, hi]
+ then mid
+ else if f mid then go mid hi f else go lo mid f
+
+orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
+orSumOuter1 (sn@SNat :: SNat n) =
+ let n = fromSNat' sn
+ in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+rshTail :: ShR (n + 1) i -> ShR n i
+rshTail (_ :$: sh) = sh
+rshTail ZSR = error "unreachable"