aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 18:09:36 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-03 18:09:36 +0200
commitcbda266091d45e564fc91462856e4f0571d18aca (patch)
tree5a7ffb173076dc5f0bbe38b89f6ad9fbb2200647
parent952228e1b598f2a7e635f41e6ecd87e81145781e (diff)
Some more generators for tests
-rw-r--r--src/Data/Array/Mixed/Types.hs5
-rw-r--r--test/Gen.hs44
2 files changed, 48 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index 52201df..35e6fd3 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -18,7 +18,7 @@ module Data.Array.Mixed.Types (
-- * Type-level naturals
pattern SZ, pattern SS,
- fromSNat',
+ fromSNat', sameNat',
snatPlus, snatMul,
-- * Type-level lists
@@ -46,6 +46,9 @@ data Dict c a where
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat
+sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
+sameNat' n@SNat m@SNat = sameNat n m
+
pattern SZ :: () => (n ~ 0) => SNat n
pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
where SZ = SNat
diff --git a/test/Gen.hs b/test/Gen.hs
index 5f84ce0..bcd7a4e 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -14,11 +14,14 @@ module Gen where
import Data.ByteString qualified as BS
import Data.Foldable (toList)
+import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
@@ -84,3 +87,44 @@ genStorables rng f = do
let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
return $ VS.generate n (f . readW64)
+genStaticShX :: (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
+genStaticShX = \k -> genRank (\sn -> go sn k)
+ where
+ go :: SNat n -> (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
+ go SZ k = k ZKX
+ go (SS n) k =
+ genItem $ \item ->
+ go n $ \ssh ->
+ k (item :!% ssh)
+
+ genItem :: (forall n. SMayNat () SNat n -> PropertyT IO ()) -> PropertyT IO ()
+ genItem k = do
+ b <- forAll Gen.bool
+ if b
+ then do
+ n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4))
+ ,(1, return 0)]
+ TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn)
+ else k (SUnknown ())
+
+genShX :: StaticShX sh -> Gen (IShX sh)
+genShX ZKX = return ZSX
+genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh
+genShX (SUnknown () :!% ssh) = do
+ dim <- Gen.int (Range.linear 1 4)
+ (SUnknown dim :$%) <$> genShX ssh
+
+genPermR :: Int -> Gen PermR
+genPermR n = Gen.shuffle [0 .. n-1]
+
+genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
+genPerm n@SNat k = do
+ list <- forAll $ genPermR (fromSNat' n)
+ permFromList list $ \perm -> do
+ case permCheckPermutation perm $
+ case sameNat' (permLengthSNat perm) n of
+ Just Refl -> Just (k perm)
+ Nothing -> Nothing
+ of
+ Just (Just act) -> act
+ _ -> error ""