aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-11 14:08:18 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-11 18:04:06 +0100
commit4aa8646599f51bbfa2006fd68738713fbb8f215a (patch)
treebafff761448b3951904137312a4e35fdffc5f731 /src/Data
parentdab29560cbd4d79577d1a1bff354c2813bbbd2c0 (diff)
Remove KnownNat from ListS and express ListS as newtype over ListX
as sketched by Tom.
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs58
1 files changed, 39 insertions, 19 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0644953..f3e2c1e 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -52,16 +52,34 @@ import Data.Array.Nested.Types
-- * Shaped lists
--- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
--- removed in a future release.
-type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+type role ListS nominal nominal
+newtype ListS sh f = ListS (ListX (MapJust sh) (WrapJust f))
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+
+type role WrapJust nominal nominal
+data WrapJust f n where
+ WrapJust :: f n -> WrapJust f (Just n)
+deriving instance (forall m. Eq (f m)) => Eq (WrapJust f n)
+deriving instance (forall m. Ord (f m)) => Ord (WrapJust f n)
+
+pattern ZS :: () => sh ~ '[] => ListS sh f
+pattern ZS <- ListS (matchZS -> Just Refl)
+ where ZS = ListS ZX
+
+matchZS :: forall sh f. ListX (MapJust sh) f -> Maybe (sh :~: '[])
+matchZS ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZS _ = Nothing
+
+pattern (::$)
+ :: forall {sh1} {f}.
+ forall n sh. (n : sh ~ sh1)
+ => f n -> ListS sh f -> ListS sh1 f
+pattern n ::$ sh <- (listsUncons -> Just (UnconsListSRes sh n))
+ where n ::$ ListS sh = ListS (WrapJust n ::% sh)
+
+{-# COMPLETE ZS, (::$) #-}
+
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -76,10 +94,12 @@ instance (forall m. NFData (f m)) => NFData (ListS n f) where
rnf (x ::$ l) = rnf x `seq` rnf l
data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
-listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
-listsUncons ZS = Nothing
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+listsUncons :: forall sh1 f. ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (ListS (WrapJust x ::% sh'))
+ | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsListSRes (ListS sh') x)
+listsUncons (ListS ZX) = Nothing
-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
@@ -188,11 +208,11 @@ listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
+ item -> item ::$ listsPermute is sh
-- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh)
+listsIndex _ _ SZ (n ::$ _) = n
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= listsIndex p pT i sh
@@ -204,7 +224,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe
-- * Shaped indices
-- | An index into a shape-typed array.
-type role IxS nominal representational
+type role IxS nominal nominal
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
deriving (Eq, Ord, Generic)
@@ -216,7 +236,7 @@ pattern ZIS = IxS ZS
-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
where i :.$ IxS shl = IxS (Const i ::$ shl)
@@ -331,7 +351,7 @@ pattern ZSS = ShS ZS
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
where i :$$ ShS shl = ShS (i ::$ shl)
@@ -414,7 +434,7 @@ shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh))
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)