diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-04 10:33:50 +0200 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-04 10:33:50 +0200 |
| commit | a9ac62f66e45e64f83043e0ebda04f0b4b80b913 (patch) | |
| tree | 4de2974a7753e97c1f1040af72f49af904ad9570 /src/Data/Array/Nested/Ranked/Shape.hs | |
| parent | 2095a851760b6bb44ba92b70df1efceff1bad267 (diff) | |
Make ranked and shaped lists newtypes over mixed
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 77 |
1 files changed, 37 insertions, 40 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index b8b5a28..b0fb30d 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -16,6 +16,7 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -49,13 +50,37 @@ import Data.Array.Nested.Types type role ListR nominal representational type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) +newtype ListR n i = ListR (ListX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor, Foldable) + +pattern ZR :: forall n i. () => n ~ 0 => ListR n i +pattern ZR <- ListR (matchZX @n -> Just Refl) + where ZR = ListR ZX + +matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZX _ = Nothing + +pattern (:::) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ListR n i -> ListR n1 i +pattern i ::: sh <- (listrUncons -> Just (UnconsListRRes sh i)) + where i ::: ListR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% sh) infixr 3 ::: +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (ListR ((::%) @n' @sh' x sh')) + | Refl <- lemReplicateHead (Proxy @n') (Proxy @sh') (Proxy @Nothing) (Proxy @n1) Refl + , Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl = + Just (UnconsListRRes (ListR @(Rank sh') sh') x) +listrUncons (ListR _) = Nothing + +{-# COMPLETE ZR, (:::) #-} + #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ListR n i) #else @@ -63,32 +88,6 @@ instance Show i => Show (ListR n i) where showsPrec _ = listrShow shows #endif -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -instance Functor (ListR n) where - {-# INLINE fmap #-} - fmap _ ZR = ZR - fmap f (x ::: xs) = f x ::: fmap f xs - -instance Foldable (ListR n) where - {-# INLINE foldMap #-} - foldMap _ ZR = mempty - foldMap f (x ::: xs) = f x <> foldMap f xs - {-# INLINE foldr #-} - foldr _ z ZR = z - foldr f z (x ::: xs) = f x (foldr f z xs) - toList = listrToList - null ZR = False - null _ = True - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - -- | This checks only whether the ranks are equal, not whether the actual -- values are. listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') @@ -122,7 +121,7 @@ listrRank :: ListR n i -> SNat n listrRank ZR = SNat listrRank (_ ::: sh) = snatSucc (listrRank sh) -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i listrAppend ZR sh = sh listrAppend (x ::: xs) sh = x ::: listrAppend xs sh @@ -185,7 +184,7 @@ listrSplitAt SZ sh = (ZR, sh) listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) listrSplitAt SS{} ZR = error "m' + 1 <= 0" -listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i +listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -273,7 +272,7 @@ ixrCast :: SNat n' -> IxR n i -> IxR n' i ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) +ixrAppend = coerce (listrAppend @n @m @i) ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 @@ -283,7 +282,7 @@ ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) +ixrPermutePrefix = coerce (listrPermutePrefix @n @i) -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -332,9 +331,9 @@ pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) - where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl - = ShR (SUnknown i :$% shl) +pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes sh i)) + where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh) +infixr 3 :$: data UnconsShRRes i n1 = forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i @@ -345,8 +344,6 @@ shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) = Just (UnconsShRRes (ShR sh') x) shrUncons (ShR _) = Nothing -infixr 3 :$: - {-# COMPLETE ZSR, (:$:) #-} type IShR n = ShR n Int |
