aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Shape.hs
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 10:33:50 +0200
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 10:33:50 +0200
commita9ac62f66e45e64f83043e0ebda04f0b4b80b913 (patch)
tree4de2974a7753e97c1f1040af72f49af904ad9570 /src/Data/Array/Nested/Ranked/Shape.hs
parent2095a851760b6bb44ba92b70df1efceff1bad267 (diff)
Make ranked and shaped lists newtypes over mixed
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs77
1 files changed, 37 insertions, 40 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index b8b5a28..b0fb30d 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -49,13 +50,37 @@ import Data.Array.Nested.Types
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
+newtype ListR n i = ListR (ListX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
+
+pattern ZR :: forall n i. () => n ~ 0 => ListR n i
+pattern ZR <- ListR (matchZX @n -> Just Refl)
+ where ZR = ListR ZX
+
+matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZX _ = Nothing
+
+pattern (:::)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ListR n i -> ListR n1 i
+pattern i ::: sh <- (listrUncons -> Just (UnconsListRRes sh i))
+ where i ::: ListR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% sh)
infixr 3 :::
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (ListR ((::%) @n' @sh' x sh'))
+ | Refl <- lemReplicateHead (Proxy @n') (Proxy @sh') (Proxy @Nothing) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl =
+ Just (UnconsListRRes (ListR @(Rank sh') sh') x)
+listrUncons (ListR _) = Nothing
+
+{-# COMPLETE ZR, (:::) #-}
+
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListR n i)
#else
@@ -63,32 +88,6 @@ instance Show i => Show (ListR n i) where
showsPrec _ = listrShow shows
#endif
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-instance Functor (ListR n) where
- {-# INLINE fmap #-}
- fmap _ ZR = ZR
- fmap f (x ::: xs) = f x ::: fmap f xs
-
-instance Foldable (ListR n) where
- {-# INLINE foldMap #-}
- foldMap _ ZR = mempty
- foldMap f (x ::: xs) = f x <> foldMap f xs
- {-# INLINE foldr #-}
- foldr _ z ZR = z
- foldr f z (x ::: xs) = f x (foldr f z xs)
- toList = listrToList
- null ZR = False
- null _ = True
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
@@ -122,7 +121,7 @@ listrRank :: ListR n i -> SNat n
listrRank ZR = SNat
listrRank (_ ::: sh) = snatSucc (listrRank sh)
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
@@ -185,7 +184,7 @@ listrSplitAt SZ sh = (ZR, sh)
listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
+listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -273,7 +272,7 @@ ixrCast :: SNat n' -> IxR n i -> IxR n' i
ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
+ixrAppend = coerce (listrAppend @n @m @i)
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
@@ -283,7 +282,7 @@ ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
+ixrPermutePrefix = coerce (listrPermutePrefix @n @i)
-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
@@ -332,9 +331,9 @@ pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
- where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
- = ShR (SUnknown i :$% shl)
+pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes sh i))
+ where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh)
+infixr 3 :$:
data UnconsShRRes i n1 =
forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
@@ -345,8 +344,6 @@ shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
= Just (UnconsShRRes (ShR sh') x)
shrUncons (ShR _) = Nothing
-infixr 3 :$:
-
{-# COMPLETE ZSR, (:$:) #-}
type IShR n = ShR n Int