aboutsummaryrefslogtreecommitdiff
path: root/test/Gen.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Gen.hs')
-rw-r--r--test/Gen.hs85
1 files changed, 85 insertions, 0 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
new file mode 100644
index 0000000..2d2a30b
--- /dev/null
+++ b/test/Gen.hs
@@ -0,0 +1,85 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Gen where
+
+import qualified Data.ByteString as BS
+import Data.Foldable (toList)
+import qualified Data.Vector.Storable as VS
+import Foreign
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
+import Data.Array.Nested
+
+import Hedgehog
+import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Range as Range
+import qualified System.Random as Random
+
+import Util
+
+
+genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
+genRank k = do
+ rank <- forAll $ Gen.int (Range.linear 0 8)
+ TN.withSomeSNat (fromIntegral rank) k
+
+genLowBiased :: RealFloat a => (a, a) -> Gen a
+genLowBiased (lo, hi) = do
+ x <- Gen.realFloat (Range.linearFrac 0 1)
+ return (lo + x * x * x * (hi - lo))
+
+shuffleShR :: IShR n -> Gen (IShR n)
+shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
+ where
+ go :: Int -> [Int] -> IShR n -> Gen (IShR n)
+ go _ _ ZSR = return ZSR
+ go nbag bag (_ :$: sh) = do
+ idx <- Gen.int (Range.linear 0 (nbag - 1))
+ let (dim, bag') = case splitAt idx bag of
+ (pre, n : post) -> (n, pre ++ post)
+ _ -> error "unreachable"
+ (dim :$:) <$> go (nbag - 1) bag' sh
+
+genShR :: SNat n -> Gen (IShR n)
+genShR sn = do
+ let n = fromSNat' sn
+ targetSize <- Gen.int (Range.linear 0 100_000)
+ let genDims :: SNat m -> Int -> Gen (IShR m)
+ genDims SZ _ = return ZSR
+ genDims (SS m) 0 = do
+ dim <- Gen.int (Range.linear 0 20)
+ dims <- genDims m 0
+ return (dim :$: dims)
+ genDims (SS m) tgt = do
+ dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
+ ,(2 , return tgt)
+ ,(4 , return 1)
+ ,(1 , return 0)]
+ dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
+ return (dim :$: dims)
+ dims <- genDims sn targetSize
+ let dimsL = toList dims
+ maxdim = maximum dimsL
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ shuffleShR (min cap <$> dims)
+
+genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
+genStorables rng f = do
+ n <- Gen.int rng
+ seed <- Gen.resize 99 $ Gen.int Range.linearBounded
+ let gen0 = Random.mkStdGen seed
+ (bs, _) = Random.genByteString (8 * n) gen0
+ let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
+ return $ VS.generate n (f . readW64)
+