diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:46:34 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:51:22 +0200 |
commit | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (patch) | |
tree | 64dcb00c9c61ad57177db5ec01c189d74dbc2d4a /test/Gen.hs | |
parent | 5a802da40e5836ee19d46b9a2c771912dbff010e (diff) |
Reorganise test files
Diffstat (limited to 'test/Gen.hs')
-rw-r--r-- | test/Gen.hs | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/test/Gen.hs b/test/Gen.hs new file mode 100644 index 0000000..2d2a30b --- /dev/null +++ b/test/Gen.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Gen where + +import qualified Data.ByteString as BS +import Data.Foldable (toList) +import qualified Data.Vector.Storable as VS +import Foreign +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) +import Data.Array.Nested + +import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import qualified System.Random as Random + +import Util + + +genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () +genRank k = do + rank <- forAll $ Gen.int (Range.linear 0 8) + TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do + x <- Gen.realFloat (Range.linearFrac 0 1) + return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh + where + go :: Int -> [Int] -> IShR n -> Gen (IShR n) + go _ _ ZSR = return ZSR + go nbag bag (_ :$: sh) = do + idx <- Gen.int (Range.linear 0 (nbag - 1)) + let (dim, bag') = case splitAt idx bag of + (pre, n : post) -> (n, pre ++ post) + _ -> error "unreachable" + (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR sn = do + let n = fromSNat' sn + targetSize <- Gen.int (Range.linear 0 100_000) + let genDims :: SNat m -> Int -> Gen (IShR m) + genDims SZ _ = return ZSR + genDims (SS m) 0 = do + dim <- Gen.int (Range.linear 0 20) + dims <- genDims m 0 + return (dim :$: dims) + genDims (SS m) tgt = do + dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) + ,(2 , return tgt) + ,(4 , return 1) + ,(1 , return 0)] + dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) + return (dim :$: dims) + dims <- genDims sn targetSize + let dimsL = toList dims + maxdim = maximum dimsL + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + shuffleShR (min cap <$> dims) + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do + n <- Gen.int rng + seed <- Gen.resize 99 $ Gen.int Range.linearBounded + let gen0 = Random.mkStdGen seed + (bs, _) = Random.genByteString (8 * n) gen0 + let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) + return $ VS.generate n (f . readW64) + |