diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
commit | c5108efd1402dcb52beca27d13b4880eed35ef5b (patch) | |
tree | b25e4ee26c1f894671db2e68c0afdaf6a1378cb5 /test/Gen.hs | |
parent | 0fd727dcb3fe05816aa9c68be5ebac84a55fcf4b (diff) |
Properly test C reductions
Diffstat (limited to 'test/Gen.hs')
-rw-r--r-- | test/Gen.hs | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/test/Gen.hs b/test/Gen.hs index 29dffb2..559fecf 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -15,6 +15,7 @@ module Gen where import Data.ByteString qualified as BS import Data.Foldable (toList) import Data.Type.Equality +import Data.Type.Ord import Data.Vector.Storable qualified as VS import Foreign import GHC.TypeLits @@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested +import Data.Array.Nested.Internal.Shape import Hedgehog import Hedgehog.Gen qualified as Gen @@ -81,6 +83,42 @@ genShR sn = do cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) shuffleShR (min cap <$> dims) +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([ 13, 4, 27 ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do + sh1 <- genShR m + (sh2, sh3) <- injectOnes n sh1 sh1 + return (sh1, sh2, sh3) + where + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) + injectOnes n@SNat shOnes sh + | m@SNat <- shrLengthSNat sh + = case cmpNat n m of + LTI -> error "unreachable" + EQI -> return (shOnes, sh) + GTI -> do + index <- Gen.int (Range.linear 0 (fromSNat' m)) + value <- Gen.int (Range.linear 1 5) + Refl <- return (lem n m) + injectOnes n (inject index 1 shOnes) (inject index value sh) + + lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem _ _ = unsafeCoerceRefl + + inject :: Int -> Int -> IShR m -> IShR (m + 1) + inject 0 v sh = v :$: sh + inject i v (w :$: sh) = w :$: inject (i - 1) v sh + inject _ v ZSR = v :$: ZSR -- invalid input, but meh + genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do n <- Gen.int rng |