1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Gen where
import Data.ByteString qualified as BS
import Data.Foldable (toList)
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
import Hedgehog
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import System.Random qualified as Random
import Util
-- | Draw a random rank and pass it to the continuation as a type-level
-- natural singleton.
--
-- The rank is an 'Int' taken from @Range.linear 0 8@ and is recorded with
-- 'forAll', so it shows up in Hedgehog's failure output.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank cont =
  forAll (Gen.int (Range.linear 0 8)) >>= \r ->
    -- Promote the runtime number to an 'SNat' before invoking the continuation.
    TN.withSomeSNat (fromIntegral r) cont
-- | Generate a floating-point value between @lo@ and @hi@ that is biased
-- towards @lo@.
--
-- A sample @t@ is drawn from @Range.linearFrac 0 1@ and the result is
-- @lo + t^3 * (hi - lo)@; cubing a number in @[0, 1]@ pushes it towards 0,
-- hence towards @lo@.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) =
  Gen.realFloat (Range.linearFrac 0 1) >>= \t ->
    let cubed = t * t * t
     in return (lo + cubed * (hi - lo))
-- | Randomly reorder the dimensions of a ranked shape.
--
-- The result has the same rank as the input. Each position, from the
-- outermost inwards, receives a dimension removed at a random index from the
-- pool of dimensions not yet used.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR shape = pick (length dimList) dimList shape
  where
    dimList = toList shape

    -- Walk the spine of the original shape (used only for its length/type),
    -- drawing one dimension from the pool per position.
    pick :: Int -> [Int] -> IShR m -> Gen (IShR m)
    pick _ _ ZSR = return ZSR
    pick remaining pool (_ :$: rest) =
      Gen.int (Range.linear 0 (remaining - 1)) >>= \i ->
        let (chosen, pool') = extract i pool
         in fmap (chosen :$:) (pick (remaining - 1) pool' rest)

    -- Remove and return the element at the given index of the pool.
    extract :: Int -> [Int] -> (Int, [Int])
    extract i xs = case splitAt i xs of
      (before, x : after) -> (x, before ++ after)
      _ -> error "unreachable"
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
let n = fromSNat' sn
targetSize <- Gen.int (Range.linear 0 100_000)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
dim <- Gen.int (Range.linear 0 20)
dims <- genDims m 0
return (dim :$: dims)
genDims (SS m) tgt = do
dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
,(2 , return tgt)
,(4 , return 1)
,(1 , return 0)]
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
dims <- genDims sn targetSize
let dimsL = toList dims
maxdim = maximum dimsL
cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
shuffleShR (min cap <$> dims)
-- | Generate a storable vector of pseudo-random elements.
--
-- The length is drawn from the given range. A single 'Int' seed is drawn
-- (at fixed size 99, so the seed range does not depend on the test size) and
-- used to produce @8 * length@ random bytes with a 'Random.StdGen'. Element
-- @i@ is obtained by reading bytes @8*i .. 8*i+7@ as a little-endian
-- 'Word64' and applying the supplied conversion function to it.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f =
  Gen.int rng >>= \count ->
    Gen.resize 99 (Gen.int Range.linearBounded) >>= \seed ->
      let (bytes, _) = Random.genByteString (8 * count) (Random.mkStdGen seed)

          -- One byte of the random buffer, widened to 'Word64'.
          byteAt :: Int -> Word64
          byteAt off = fromIntegral (BS.index bytes off)

          -- Little-endian decode: byte j of the word carries weight 256^j.
          word64At ix =
            sum [w * byteAt (8 * ix + j) | (w, j) <- zip (iterate (* 256) 1) [0 .. 7]]
       in return (VS.generate count (f . word64At))
-- | Generate a random static mixed shape and pass it to the continuation.
--
-- The rank comes from 'genRank'. For each position a boolean decides whether
-- the dimension is statically known or unknown; a known dimension is drawn
-- from @Range.linear 1 4@ with weight 20, or is 0 with weight 1.
genStaticShX :: (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
genStaticShX cont = genRank (\rank -> build rank cont)
  where
    -- Build the static shape one entry at a time, in continuation-passing
    -- style so the type-level shape can grow with each step.
    build :: SNat n -> (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
    build SZ done = done ZKX
    build (SS prev) done =
      genEntry (\entry -> build prev (\tailShape -> done (entry :!% tailShape)))

    -- Generate a single entry: either a known size or an unknown marker.
    genEntry :: (forall n. SMayNat () SNat n -> PropertyT IO ()) -> PropertyT IO ()
    genEntry use =
      forAll Gen.bool >>= \known ->
        case known of
          True ->
            forAll (Gen.frequency [(20, Gen.int (Range.linear 1 4)), (1, return 0)]) >>= \size ->
              TN.withSomeSNat (fromIntegral size) (\sn -> use (SKnown sn))
          False -> use (SUnknown ())
-- | Generate a concrete mixed shape that matches the given static shape.
--
-- Statically known dimensions are copied over unchanged; each unknown
-- dimension is filled in with a value drawn from @Range.linear 1 4@.
genShX :: StaticShX sh -> Gen (IShX sh)
genShX ZKX = return ZSX
genShX (item :!% rest) = case item of
  SKnown sn -> fmap (SKnown sn :$%) (genShX rest)
  SUnknown () ->
    Gen.int (Range.linear 1 4) >>= \dim ->
      fmap (SUnknown dim :$%) (genShX rest)
-- | Generate a random ordering of the indices @0 .. n-1@ (the empty list
-- when @n <= 0@), by shuffling that list with 'Gen.shuffle'.
genPermR :: Int -> Gen PermR
genPermR n = Gen.shuffle indices
  where
    indices = enumFromTo 0 (n - 1)
-- | Generate a random permutation of the given rank and pass it, as a
-- type-level 'Perm', to the continuation.
--
-- A shuffled index list is drawn with 'genPermR' (and recorded via 'forAll'),
-- lifted to the type level with 'permFromList', and then checked twice so the
-- continuation receives the evidence it needs: 'permCheckPermutation'
-- supplies the 'IsPermutation' constraint, and 'sameNat'' supplies the proof
-- that the permutation's length equals the requested rank.
--
-- Both checks are expected to succeed for a list produced by 'genPermR'; if
-- either fails, an 'error' naming the failed check and showing the generated
-- list is raised.
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
  list <- forAll $ genPermR (fromSNat' n)
  permFromList list $ \perm ->
    case permCheckPermutation perm $
           case sameNat' (permLengthSNat perm) n of
             Just Refl -> Just (k perm)
             Nothing -> Nothing
         of
      Just (Just act) -> act
      -- The length check (sameNat') failed.
      Just Nothing ->
        error ("genPerm: impossible: length of generated permutation differs from requested rank: " ++ show list)
      -- permCheckPermutation itself returned Nothing.
      Nothing ->
        error ("genPerm: impossible: permCheckPermutation rejected generated list: " ++ show list)
|