{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}

-- | Hedgehog generators for ranks, ranked shapes and storable vectors.
module Gen where

import qualified Data.ByteString as BS
import Data.Foldable (toList)
import qualified Data.Vector.Storable as VS
import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import Data.Array.Mixed.Types
import Data.Array.Nested

import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified System.Random as Random

import Util


-- | Draw a rank from the linear range 0..8 (reported through 'forAll') and
-- hand the corresponding singleton to the continuation.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank cont = do
  r <- forAll (Gen.int (Range.linear 0 8))
  TN.withSomeSNat (fromIntegral r) cont

-- | Generate a value between @lo@ and @hi@ that is skewed towards @lo@:
-- a uniform-range sample @t@ from @[0, 1]@ is cubed before being scaled
-- onto the interval.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = do
  t <- Gen.realFloat (Range.linearFrac 0 1)
  -- Deliberately a monadic bind followed by 'pure' (not 'fmap'): the two
  -- consume the generator seed differently in Hedgehog.
  pure (lo + cube t * (hi - lo))
  where
    cube v = v * v * v

-- | Randomly reorder the dimensions of a shape. For every position of the
-- result, one of the dimensions not yet used is picked and removed from
-- the pool.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR sh0 = pick (length allDims) allDims sh0
  where
    allDims :: [Int]
    allDims = toList sh0

    -- First argument: number of entries left in the pool (second argument).
    -- The third argument only drives the recursion; its dimensions are
    -- ignored.
    pick :: Int -> [Int] -> IShR n -> Gen (IShR n)
    pick _ _ ZSR = return ZSR
    pick remaining pool (_ :$: rest) = do
      i <- Gen.int (Range.linear 0 (remaining - 1))
      let (chosen, pool') = removeAt i pool
      (chosen :$:) <$> pick (remaining - 1) pool' rest

    -- Take the element at the given index out of the list, returning it
    -- together with the list of the other elements.
    removeAt :: Int -> [Int] -> (Int, [Int])
    removeAt i xs =
      case splitAt i xs of
        (before, x : after) -> (x, before ++ after)
        _ -> error "unreachable"

-- | Generate a shape of the given rank.
--
-- A target element count is drawn from 0..100000 and dimensions are then
-- generated one at a time, each reducing the target left for the
-- remaining ones. Afterwards every dimension is clamped to a common cap
-- found with 'binarySearch' using the predicate \"the product of the
-- clamped dimensions does not exceed the target\", and the dimensions are
-- shuffled.
--
-- NOTE(review): 'binarySearch' comes from "Util" and is not visible here;
-- exactly which cap it selects should be confirmed there.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  targetSize <- Gen.int (Range.linear 0 100_000)
  dims <- genDims sn targetSize
  let dimsL = toList dims
      cappedSize c = product (map (min c) dimsL)
      -- 'maximum' is only evaluated if 'cap' is demanded, which does not
      -- happen when mapping over an empty (rank-0) shape.
      cap = binarySearch (`div` 2) 1 (maximum dimsL) (\c -> cappedSize c <= targetSize)
  shuffleShR (fmap (min cap) dims)
  where
    rankInt = fromSNat' sn

    genDims :: SNat m -> Int -> Gen (IShR m)
    genDims SZ _ = return ZSR
    -- Once the target has reached zero, the remaining dimensions are drawn
    -- freely from 0..20.
    genDims (SS m) 0 = do
      d <- Gen.int (Range.linear 0 20)
      ds <- genDims m 0
      return (d :$: ds)
    genDims (SS m) tgt = do
      d <-
        Gen.frequency
          [ -- Most likely (weight grows with the overall rank): a
            -- low-biased value between 2 and the square root of the target.
            (20 * rankInt, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
          , -- Use up the whole remaining target in this dimension.
            (2, return tgt)
          , (4, return 1)
          , (1, return 0)
          ]
      let nextTarget
            | d == 0 = 0
            | otherwise = tgt `div` d
      ds <- genDims m nextTarget
      return (d :$: ds)

-- | Generate a storable vector whose length is drawn from the given range.
--
-- A seed is generated and used to produce @8 * length@ pseudo-random
-- bytes; element @i@ is obtained by combining bytes @8*i .. 8*i+7@ into a
-- 'Word64' (least significant byte first) and applying the supplied
-- conversion function.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  count <- Gen.int rng
  -- NOTE(review): the size is pinned to 99 here, presumably so that the
  -- seed is drawn from (nearly) the full 'Int' range regardless of the
  -- current test size; confirm the intent.
  seed <- Gen.resize 99 $ Gen.int Range.linearBounded
  let bytes = fst (Random.genByteString (8 * count) (Random.mkStdGen seed))

      wordAt :: Int -> Word64
      wordAt i =
        sum [256 ^ j * fromIntegral (BS.index bytes (8 * i + j)) | j <- [0 .. 7 :: Int]]
  return (VS.generate count (\i -> f (wordAt i)))