aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-01 17:50:42 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-01 18:14:40 +0100
commit9faf7fb877119bd52d664940c4326d326b3326fa (patch)
tree58063f871b11a4dd5089dca28bb561bc322a875f
parent0028b655341069e83db6e0bfde01dea1c696f5aa (diff)
Don't call continuation-based functions just *FromList
-rw-r--r--src/Data/Array/Nested/Permutation.hs10
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs21
-rw-r--r--test/Gen.hs2
-rw-r--r--test/Tests/Permutation.hs2
4 files changed, 21 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 045b18f..9eae73d 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -57,11 +57,11 @@ permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r
+permFromListCont [] k = k PNil
+permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Mixed.permFromListCont: negative number in list: " ++ show x
permToList :: Perm list -> [Natural]
permToList PNil = mempty
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 739f0de..4d581a0 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -32,6 +33,7 @@ import Control.DeepSeq (NFData(..))
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
+import Data.List (genericLength)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#)
@@ -128,9 +130,12 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+listrFromList :: SNat n -> [i] -> ListR n i
+listrFromList SZ [] = ZR
+listrFromList (SS n) (i : is) = i ::: listrFromList n is
+listrFromList n l = error $ "listrFromList: Mismatched list length (type says "
+ ++ show (fromSNat n) ++ ", list has length "
+ ++ show (length l) ++ ")"
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
@@ -169,14 +174,16 @@ listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case
+ Just permlen@SNat->
+ let sperm = listrFromList permlen perm
+ in case listrRank sh of
+ shlen@SNat -> case cmpNat permlen shlen of
LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ Nothing -> error "listrPermutePrefix: impossible negative list length"
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)
diff --git a/test/Gen.hs b/test/Gen.hs
index e665af6..b10e763 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -163,7 +163,7 @@ genPermR n = Gen.shuffle [0 .. n-1]
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
list <- forAll $ genPermR (fromSNat' n)
- permFromList list $ \perm -> do
+ permFromListCont list $ \perm -> do
case permCheckPermutation perm $
case sameNat' (permRank perm) n of
Just Refl -> Just (k perm)
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 98a6da5..4e75d64 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -24,7 +24,7 @@ tests = testGroup "Permutation"
[testProperty "permCheckPermutation" $ property $ do
n <- forAll $ Gen.int (Range.linear 0 10)
list <- forAll $ genPermR n
- let r = permFromList list $ \perm ->
+ let r = permFromListCont list $ \perm ->
permCheckPermutation perm ()
case r of
Just () -> return ()