about summary refs log tree commit diff
path: root/test/Gen.hs
blob: 2d2a30babb72c511f81b9cd0b53544334eafb2fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Gen where

import qualified Data.ByteString as BS
import Data.Foldable (toList)
import qualified Data.Vector.Storable as VS
import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested

import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified System.Random as Random

import Util


-- | Run a rank-polymorphic property continuation at a randomly chosen rank.
--
-- The rank is drawn from @Range.linear 0 8@ and recorded with 'forAll' (so it
-- is shown on failure and shrinks). It is then turned into a type-level
-- natural with 'TN.withSomeSNat', and @k@ receives the corresponding 'SNat'.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k =
  forAll (Gen.int (Range.linear 0 8)) >>= \r ->
    TN.withSomeSNat (fromIntegral r) k

-- | Generate a value between @lo@ and @hi@ that is skewed towards @lo@.
--
-- A sample @u@ is drawn from @Range.linearFrac 0 1@ and cubed before being
-- scaled into the interval, so values close to @lo@ are much more likely
-- than values close to @hi@.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = scale <$> Gen.realFloat (Range.linearFrac 0 1)
  where
    -- Cubing (rather than using u directly) is what produces the low bias.
    scale u = lo + u * u * u * (hi - lo)

-- | Randomly permute the dimensions of a shape, keeping its rank.
--
-- The dimensions are collected into a pool. Walking the shape's spine, each
-- position takes a randomly chosen element out of the pool. Indices are
-- drawn with @Range.linear@, so their spread follows Hedgehog's size
-- parameter.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR sh0 = pick (length dims0) dims0 sh0
  where
    dims0 = toList sh0

    -- @count@ is always the length of @pool@, and the pool holds exactly as
    -- many elements as the remaining spine has positions.
    pick :: Int -> [Int] -> IShR m -> Gen (IShR m)
    pick _ _ ZSR = pure ZSR
    pick count pool (_ :$: rest) = do
      i <- Gen.int (Range.linear 0 (count - 1))
      case splitAt i pool of
        (before, chosen : after) -> (chosen :$:) <$> pick (count - 1) (before ++ after) rest
        _ -> error "unreachable"

-- | Generate a random shape of rank @n@ whose element count is bounded by a
-- randomly drawn target size (from @Range.linear 0 100_000@).
--
-- Dimensions are drawn front to back, and each one divides down the target
-- that remains for the following dimensions. Afterwards every dimension is
-- clamped to a common cap, chosen with 'binarySearch' against the predicate
-- @product (clamped dims) <= targetSize@. Finally the dimensions are shuffled
-- with 'shuffleShR'. Earlier dimensions are drawn against a larger remaining
-- target, so without the shuffle they would tend to be the larger ones.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  let n = fromSNat' sn
  targetSize <- Gen.int (Range.linear 0 100_000)
  let genDims :: SNat m -> Int -> Gen (IShR m)
      genDims SZ _ = return ZSR
      -- The remaining target is 0 in two cases. Either a zero dimension was
      -- already drawn (so the product is 0 whatever follows), or integer
      -- division rounded the target down. In both cases, draw small
      -- dimensions freely and leave the final capping step to bound the
      -- product.
      genDims (SS m) 0 = do
        dim <- Gen.int (Range.linear 0 20)
        dims <- genDims m 0
        return (dim :$: dims)
      genDims (SS m) tgt = do
        -- Mostly a low-biased dimension between 2 and max 2 (sqrt tgt); its
        -- weight grows with the rank. Occasionally the dimension is instead
        -- the whole remaining target, 1, or 0.
        dim <- Gen.frequency
                 [ (20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
                 , (2     , return tgt)
                 , (4     , return 1)
                 , (1     , return 0)
                 ]
        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
        return (dim :$: dims)
  dims <- genDims sn targetSize
  case toList dims of
    -- Rank 0 has nothing to cap or shuffle (shuffleShR would return ZSR
    -- without drawing anything). Handling it here guarantees that 'maximum'
    -- below is never applied to an empty list; previously that relied on
    -- 'cap' happening to stay unevaluated.
    [] -> return dims
    dimsL -> do
      let maxdim = maximum dimsL
          -- NOTE(review): assumes 'binarySearch' (from Util) returns the
          -- largest cap in [1, maxdim] satisfying the predicate; confirm there.
          cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
      shuffleShR (min cap <$> dims)

-- | Generate a storable vector whose elements are built from pseudo-random
-- 'Word64's.
--
-- Only the length (drawn from @rng@) and a seed go through Hedgehog. The
-- element bits come from a 'Random.StdGen' built from that seed, which keeps
-- large vectors cheap to generate. Element @i@ is @f@ applied to bytes
-- @8*i .. 8*i+7@ of the generated buffer, decoded little-endian.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  len <- Gen.int rng
  -- Pin the size so the seed is drawn from the full Int range regardless of
  -- the current Hedgehog size.
  seed <- Gen.resize 99 (Gen.int Range.linearBounded)
  let bytes = fst (Random.genByteString (8 * len) (Random.mkStdGen seed))
      -- Little-endian: byte 8*i is the least significant one.
      word64At i =
        foldr (\j acc -> acc * 256 + fromIntegral (BS.index bytes (8 * i + j))) 0 [0 .. 7]
  pure (VS.generate len (f . word64At))