diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 739f0de..4d581a0 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -7,6 +7,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} @@ -32,6 +33,7 @@ import Control.DeepSeq (NFData(..)) import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Type) +import Data.List (genericLength) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, quotRemInt#) @@ -128,9 +130,12 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i listrAppend ZR sh = sh listrAppend (x ::: xs) sh = x ::: listrAppend xs sh -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) +listrFromList :: SNat n -> [i] -> ListR n i +listrFromList SZ [] = ZR +listrFromList (SS n) (i : is) = i ::: listrFromList n is +listrFromList n l = error $ "listrFromList: Mismatched list length (type says " + ++ show (fromSNat n) ++ ", list has length " + ++ show (length l) ++ ")" listrHead :: ListR (n + 1) i -> i listrHead (i ::: _) = i @@ -169,14 +174,16 @@ listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of +listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case + Just permlen@SNat-> + let sperm = listrFromList permlen perm + in case listrRank sh of + shlen@SNat -> case cmpNat permlen shlen of LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + Nothing -> error "listrPermutePrefix: impossible negative list length" where listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) listrSplitAt SZ sh = (ZR, sh) |
