diff options
Diffstat (limited to 'test/Gen.hs')
-rw-r--r-- | test/Gen.hs | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/test/Gen.hs b/test/Gen.hs index 244c735..044de14 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -2,7 +2,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NumericUnderscores #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} @@ -21,11 +20,10 @@ import Foreign import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types import Data.Array.Nested -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Hedgehog import Hedgehog.Gen qualified as Gen @@ -49,7 +47,7 @@ genLowBiased (lo, hi) = do return (lo + x * x * x * (hi - lo)) shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh +shuffleShR = \sh -> go (length sh) (toList sh) sh where go :: Int -> [Int] -> IShR n -> Gen (IShR n) go _ _ ZSR = return ZSR @@ -61,9 +59,12 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh (dim :$:) <$> go (nbag - 1) bag' sh genShR :: SNat n -> Gen (IShR n) -genShR sn = do +genShR = genShRwithTarget 100_000 + +genShRwithTarget :: Int -> SNat n -> Gen (IShR n) +genShRwithTarget targetMax sn = do let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 100_000) + targetSize <- Gen.int (Range.linear 0 targetMax) let genDims :: SNat m -> Int -> Gen (IShR m) genDims SZ _ = return ZSR genDims (SS m) 0 = do @@ -95,10 +96,14 @@ genShR sn = do -- other dimensions might be zero. genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) genReplicatedShR = \m n -> do - sh1 <- genShR m + let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m)) + sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m (sh2, sh3) <- injectOnes n sh1 sh1 return (sh1, sh2, sh3) where + repvalrange = (1::Int, 5) + repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) injectOnes n@SNat shOnes sh | m@SNat <- shrRank sh @@ -107,17 +112,17 @@ genReplicatedShR = \m n -> do EQI -> return (shOnes, sh) GTI -> do index <- Gen.int (Range.linear 0 (fromSNat' m)) - value <- Gen.int (Range.linear 1 5) + value <- Gen.int (uncurry Range.linear repvalrange) Refl <- return (lem n m) injectOnes n (inject index 1 shOnes) (inject index value sh) - lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True lem _ _ = unsafeCoerceRefl inject :: Int -> Int -> IShR m -> IShR (m + 1) inject 0 v sh = v :$: sh inject i v (w :$: sh) = w :$: inject (i - 1) v sh - inject _ v ZSR = v :$: ZSR -- invalid input, but meh + inject _ _ ZSR = error "unreachable" genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do |