aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs21
1 files changed, 14 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 739f0de..4d581a0 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -32,6 +33,7 @@ import Control.DeepSeq (NFData(..))
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
+import Data.List (genericLength)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#)
@@ -128,9 +130,12 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+listrFromList :: SNat n -> [i] -> ListR n i
+listrFromList SZ [] = ZR
+listrFromList (SS n) (i : is) = i ::: listrFromList n is
+listrFromList n l = error $ "listrFromList: Mismatched list length (type says "
+ ++ show (fromSNat n) ++ ", list has length "
+ ++ show (length l) ++ ")"
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
@@ -169,14 +174,16 @@ listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case
+ Just permlen@SNat->
+ let sperm = listrFromList permlen perm
+ in case listrRank sh of
+ shlen@SNat -> case cmpNat permlen shlen of
LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ Nothing -> error "listrPermutePrefix: impossible negative list length"
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)