aboutsummaryrefslogtreecommitdiff
path: root/test/Gen.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Gen.hs')
-rw-r--r--test/Gen.hs38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
index 29dffb2..559fecf 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -15,6 +15,7 @@ module Gen where
import Data.ByteString qualified as BS
import Data.Foldable (toList)
import Data.Type.Equality
+import Data.Type.Ord
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
@@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
+import Data.Array.Nested.Internal.Shape
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -81,6 +83,42 @@ genShR sn = do
cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
shuffleShR (min cap <$> dims)
+-- | Example: given 3 and 7, might return:
+--
+-- @
+-- ([ 13, 4, 27 ]
+-- ,[1, 13, 1, 1, 4, 27, 1]
+-- ,[4, 13, 1, 3, 4, 27, 2])
+-- @
+--
+-- The up-replicated dimensions are always nonzero and not very large, but the
+-- other dimensions might be zero.
+genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
+genReplicatedShR = \m n -> do
+ sh1 <- genShR m
+ (sh2, sh3) <- injectOnes n sh1 sh1
+ return (sh1, sh2, sh3)
+ where
+ injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+ injectOnes n@SNat shOnes sh
+ | m@SNat <- shrLengthSNat sh
+ = case cmpNat n m of
+ LTI -> error "unreachable"
+ EQI -> return (shOnes, sh)
+ GTI -> do
+ index <- Gen.int (Range.linear 0 (fromSNat' m))
+ value <- Gen.int (Range.linear 1 5)
+ Refl <- return (lem n m)
+ injectOnes n (inject index 1 shOnes) (inject index value sh)
+
+ lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem _ _ = unsafeCoerceRefl
+
+ inject :: Int -> Int -> IShR m -> IShR (m + 1)
+ inject 0 v sh = v :$: sh
+ inject i v (w :$: sh) = w :$: inject (i - 1) v sh
+ inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
n <- Gen.int rng