diff options
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 21 | ||||
| -rw-r--r-- | test/Gen.hs | 2 | ||||
| -rw-r--r-- | test/Tests/Permutation.hs | 2 |
4 files changed, 21 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 045b18f..9eae73d 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -57,11 +57,11 @@ permRank :: Perm list -> SNat (Rank list) permRank PNil = SNat permRank (_ `PCons` l) | SNat <- permRank l = SNat -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r +permFromListCont [] k = k PNil +permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromListCont: negative number in list: " ++ show x permToList :: Perm list -> [Natural] permToList PNil = mempty diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 739f0de..4d581a0 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -7,6 +7,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} @@ -32,6 +33,7 @@ import Control.DeepSeq (NFData(..)) import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Type) +import Data.List (genericLength) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, quotRemInt#) @@ -128,9 +130,12 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i listrAppend ZR sh = sh listrAppend (x ::: xs) sh = x ::: listrAppend xs sh -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) +listrFromList :: SNat n -> [i] -> ListR n i +listrFromList SZ [] = ZR +listrFromList (SS n) (i : is) = i ::: listrFromList n is +listrFromList n l = error $ "listrFromList: Mismatched list length (type says " + ++ show (fromSNat n) ++ ", list has length " + ++ show (length l) ++ ")" listrHead :: ListR (n + 1) i -> i listrHead (i ::: _) = i @@ -169,14 +174,16 @@ listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of +listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case + Just permlen@SNat-> + let sperm = listrFromList permlen perm + in case listrRank sh of + shlen@SNat -> case cmpNat permlen shlen of LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + Nothing -> error "listrPermutePrefix: impossible negative list length" where listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) listrSplitAt SZ sh = (ZR, sh) diff --git a/test/Gen.hs b/test/Gen.hs index e665af6..b10e763 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -163,7 +163,7 @@ genPermR n = Gen.shuffle [0 .. n-1] genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r genPerm n@SNat k = do list <- forAll $ genPermR (fromSNat' n) - permFromList list $ \perm -> do + permFromListCont list $ \perm -> do case permCheckPermutation perm $ case sameNat' (permRank perm) n of Just Refl -> Just (k perm) diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs index 98a6da5..4e75d64 100644 --- a/test/Tests/Permutation.hs +++ b/test/Tests/Permutation.hs @@ -24,7 +24,7 @@ tests = testGroup "Permutation" [testProperty "permCheckPermutation" $ property $ do n <- forAll $ Gen.int (Range.linear 0 10) list <- forAll $ genPermR n - let r = permFromList list $ \perm -> + let r = permFromListCont list $ \perm -> permCheckPermutation perm () case r of Just () -> return () |
