aboutsummaryrefslogtreecommitdiff
path: root/test/Gen.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Gen.hs')
-rw-r--r--test/Gen.hs44
1 files changed, 44 insertions, 0 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
index 5f84ce0..bcd7a4e 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -14,11 +14,14 @@ module Gen where
import Data.ByteString qualified as BS
import Data.Foldable (toList)
+import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
@@ -84,3 +87,44 @@ genStorables rng f = do
let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
return $ VS.generate n (f . readW64)
+genStaticShX :: (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
+genStaticShX = \k -> genRank (\sn -> go sn k)
+ where
+ go :: SNat n -> (forall sh. StaticShX sh -> PropertyT IO ()) -> PropertyT IO ()
+ go SZ k = k ZKX
+ go (SS n) k =
+ genItem $ \item ->
+ go n $ \ssh ->
+ k (item :!% ssh)
+
+ genItem :: (forall n. SMayNat () SNat n -> PropertyT IO ()) -> PropertyT IO ()
+ genItem k = do
+ b <- forAll Gen.bool
+ if b
+ then do
+ n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4))
+ ,(1, return 0)]
+ TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn)
+ else k (SUnknown ())
+
+genShX :: StaticShX sh -> Gen (IShX sh)
+genShX ZKX = return ZSX
+genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh
+genShX (SUnknown () :!% ssh) = do
+ dim <- Gen.int (Range.linear 1 4)
+ (SUnknown dim :$%) <$> genShX ssh
+
+genPermR :: Int -> Gen PermR
+genPermR n = Gen.shuffle [0 .. n-1]
+
+genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
+genPerm n@SNat k = do
+ list <- forAll $ genPermR (fromSNat' n)
+ permFromList list $ \perm -> do
+ case permCheckPermutation perm $
+ case sameNat' (permLengthSNat perm) n of
+ Just Refl -> Just (k perm)
+ Nothing -> Nothing
+ of
+ Just (Just act) -> act
+ _ -> error ""