diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-11 14:08:18 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-11 18:04:06 +0100 |
| commit | 4aa8646599f51bbfa2006fd68738713fbb8f215a (patch) | |
| tree | bafff761448b3951904137312a4e35fdffc5f731 /src/Data/Array/Nested/Shaped | |
| parent | dab29560cbd4d79577d1a1bff354c2813bbbd2c0 (diff) | |
Remove KnownNat from ListS and express ListS as newtype over ListX
as sketched by Tom.
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 58 |
1 files changed, 39 insertions, 19 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0644953..f3e2c1e 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -52,16 +52,34 @@ import Data.Array.Nested.Types -- * Shaped lists --- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be --- removed in a future release. -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +type role ListS nominal nominal +newtype ListS sh f = ListS (ListX (MapJust sh) (WrapJust f)) deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) + +type role WrapJust nominal nominal +data WrapJust f n where + WrapJust :: f n -> WrapJust f (Just n) +deriving instance (forall m. Eq (f m)) => Eq (WrapJust f n) +deriving instance (forall m. Ord (f m)) => Ord (WrapJust f n) + +pattern ZS :: () => sh ~ '[] => ListS sh f +pattern ZS <- ListS (matchZS -> Just Refl) + where ZS = ListS ZX + +matchZS :: forall sh f. ListX (MapJust sh) f -> Maybe (sh :~: '[]) +matchZS ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZS _ = Nothing + +pattern (::$) + :: forall {sh1} {f}. + forall n sh. (n : sh ~ sh1) + => f n -> ListS sh f -> ListS sh1 f +pattern n ::$ sh <- (listsUncons -> Just (UnconsListSRes sh n)) + where n ::$ ListS sh = ListS (WrapJust n ::% sh) + +{-# COMPLETE ZS, (::$) #-} + infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -76,10 +94,12 @@ instance (forall m. NFData (f m)) => NFData (ListS n f) where rnf (x ::$ l) = rnf x `seq` rnf l data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: forall sh1 f. ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (ListS (WrapJust x ::% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsListSRes (ListS sh') x) +listsUncons (ListS ZX) = Nothing -- | This checks only whether the types are equal; if the elements of the list -- are not singletons, their values may still differ. This corresponds to @@ -188,11 +208,11 @@ listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh + item -> item ::$ listsPermute is sh -- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh) +listsIndex _ _ SZ (n ::$ _) = n listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listsIndex p pT i sh @@ -204,7 +224,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe -- * Shaped indices -- | An index into a shape-typed array. -type role IxS nominal representational +type role IxS nominal nominal type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) deriving (Eq, Ord, Generic) @@ -216,7 +236,7 @@ pattern ZIS = IxS ZS -- removed in a future release. pattern (:.$) :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) where i :.$ IxS shl = IxS (Const i ::$ shl) @@ -331,7 +351,7 @@ pattern ZSS = ShS ZS pattern (:$$) :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) where i :$$ ShS shl = ShS (i ::$ shl) @@ -414,7 +434,7 @@ shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = coerce (listsPermute @SNat) shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) +shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh)) shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) shsPermutePrefix = coerce (listsPermutePrefix @SNat) |
