From 091d8e3cc4f150ea1eb48953db18f4a352c5179a Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Thu, 9 Apr 2026 22:24:45 +0200 Subject: Remove ListS --- src/Data/Array/Nested.hs | 1 - src/Data/Array/Nested/Convert.hs | 1 - src/Data/Array/Nested/Mixed/Shape.hs | 2 +- src/Data/Array/Nested/Shaped/Shape.hs | 202 +++++++++++----------------------- 4 files changed, 67 insertions(+), 139 deletions(-) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index f022fe0..ec81843 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -33,7 +33,6 @@ module Data.Array.Nested ( -- * Shaped arrays Shaped(Shaped), - ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim, diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 6be0f74..2595c64 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -42,7 +42,6 @@ import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 54e19e1..9869d03 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -649,7 +649,7 @@ shxFlatten = go (SNat @1) -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where +instance IsList (IxX sh i) where type Item (IxX sh i) = i fromList = IxX . IsList.fromList toList = Foldable.toList diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index d1a72ea..392ceac 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -42,137 +42,38 @@ import Data.Array.Nested.Permutation import Data.Array.Nested.Types --- * Shaped lists - -type role ListS nominal representational -type ListS :: [Nat] -> Type -> Type -newtype ListS sh i = ListS (ListX (MapJust sh) i) - deriving (Eq, Ord, NFData, Functor, Foldable) - -pattern ZS :: forall sh i. () => sh ~ '[] => ListS sh i -pattern ZS <- ListS (matchZX -> Just Refl) - where ZS = ListS ZX - -matchZX :: forall sh i. ListX (MapJust sh) i -> Maybe (sh :~: '[]) -matchZX ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl -matchZX _ = Nothing - -pattern (::$) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> ListS sh i -> ListS sh1 i -pattern i ::$ l <- (listsUncons -> Just (UnconsListSRes i l)) - where i ::$ ListS l = ListS (i ::% l) -infixr 3 ::$ - -data UnconsListSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListSRes i (ListS sh i) -listsUncons :: forall sh1 i. ListS sh1 i -> Maybe (UnconsListSRes i sh1) -listsUncons (ListS (i ::% l)) | Refl <- lemMapJustHead (Proxy @sh1) - , Refl <- lemMapJustCons @sh1 Refl = - Just (UnconsListSRes i (ListS l)) -listsUncons (ListS _) = Nothing - -{-# COMPLETE ZS, (::$) #-} - -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance Show i => Show (ListS sh i) -#else -instance Show i => Show (ListS sh i) where - showsPrec _ = listsShow shows -#endif - -listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS -listsShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListS sh' i -> ShowS - go _ ZS = id - go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listsRank :: ListS sh i -> SNat (Rank sh) -listsRank ZS = SNat -listsRank (_ ::$ sh) = snatSucc (listsRank sh) - -{-# INLINE listsFromList #-} -listsFromList :: ShS sh -> [i] -> ListS sh i -listsFromList sh l = assert (shsLength sh == length l) - $ ListS $ IsList.fromList l - -{-# INLINE listsFromListS #-} -listsFromListS :: ListS sh i0 -> [i] -> ListS sh i -listsFromListS sh l = assert (length sh == length l) - $ ListS $ IsList.fromList l - -listsHead :: ListS (n : sh) i -> i -listsHead (i ::$ _) = i - -listsTail :: ListS (n : sh) i -> ListS sh i -listsTail (_ ::$ sh) = sh - -listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i -listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh -listsInit (_ ::$ ZS) = ZS - -listsLast :: ListS (n : sh) i -> i -listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh -listsLast (n ::$ ZS) = n - -listsAppend :: forall sh sh' i. ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i -listsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ - coerce (listxAppend @_ @_ @i) - -listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j) -listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js - -{-# INLINE listsZipWith #-} -listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k -listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js - -listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i -listsTakeLenPerm PNil _ = ZS -listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh -listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i -listsDropLenPerm PNil sh = sh -listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh -listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex i sh of - item -> item ::$ listsPermute is sh - -listsIndex :: forall j i sh. SNat i -> ListS sh j -> j -listsIndex SZ (n ::$ _) = n -listsIndex (SS i) (_ ::$ sh) = listsIndex i sh -listsIndex _ ZS = error "Index into empty shape" - -listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i -listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) - -- * Shaped indices -- | An index into a shape-typed array. type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh i) +newtype IxS sh i = IxS (IxX (MapJust sh) i) deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS +pattern ZIS <- IxS (matchZIX -> Just Refl) + where ZIS = IxS ZIX + +matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZIX _ = Nothing pattern (:.$) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ l <- IxS (i ::$ (IxS -> l)) - where i :.$ IxS l = IxS (i ::$ l) +pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l)) + where i :.$ IxS l = IxS (i :.% l) infixr 3 :.$ +data UnconsIxSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i) +ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1) +ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1) + , Refl <- lemMapJustCons @sh1 Refl = + Just (UnconsIxSRes i (IxS l)) +ixsUncons (IxS _) = Nothing + {-# COMPLETE ZIS, (:.$) #-} -- For convenience, this contains regular 'Int's instead of bounded integers @@ -183,41 +84,55 @@ type IIxS sh = IxS sh Int deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow shows l + showsPrec _ l = ixsShow shows l #endif +ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS +ixsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxS sh' i -> ShowS + go _ ZIS = id + go prefix (x :.$ xs) = showString prefix . f x . go "," xs + ixsRank :: IxS sh i -> SNat (Rank sh) -ixsRank (IxS l) = listsRank l +ixsRank ZIS = SNat +ixsRank (_ :.$ sh) = snatSucc (ixsRank sh) -ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i -ixsFromList = coerce (listsFromList @_ @i) +{-# INLINE ixsFromList #-} +ixsFromList :: ShS sh -> [i] -> IxS sh i +ixsFromList sh l = assert (shsLength sh == length l) + $ IxS $ IsList.fromList l -{-# INLINEABLE ixsFromIxS #-} -ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i -ixsFromIxS = coerce (listsFromListS @_ @i0 @i) +{-# INLINE ixsFromIxS #-} +ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS sh l = assert (length sh == length l) + $ IxS $ IsList.fromList l ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = listsHead list +ixsHead (i :.$ _) = i ixsTail :: IxS (n : sh) i -> IxS sh i -ixsTail (IxS list) = IxS (listsTail list) +ixsTail (_ :.$ sh) = sh ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i -ixsInit (IxS list) = IxS (listsInit list) +ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh +ixsInit (_ :.$ ZIS) = ZIS ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = listsLast list +ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh +ixsLast (n :.$ ZIS) = n ixsCast :: IxS sh i -> IxS sh i ixsCast ZIS = ZIS ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @_ @i) +ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ + coerce (listxAppend @_ @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS @@ -228,8 +143,29 @@ ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js +ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i +ixsTakeLenPerm PNil _ = ZIS +ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh +ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i +ixsDropLenPerm PNil sh = sh +ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh +ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i +ixsPermute PNil _ = ZIS +ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) = + case ixsIndex i sh of + item -> item :.$ ixsPermute is sh + +ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j +ixsIndex SZ (n :.$ _) = n +ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh +ixsIndex _ ZIS = error "Index into empty shape" + ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @i) +ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh) -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -450,16 +386,10 @@ shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh i) where - type Item (ListS sh i) = i - fromList = listsFromList (knownShS @sh) - toList = Foldable.toList - -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. instance KnownShS sh => IsList (IxS sh i) where type Item (IxS sh i) = i - fromList = IxS . IsList.fromList + fromList = ixsFromList (knownShS @sh) toList = Foldable.toList -- | Untyped: length and values are checked at runtime. -- cgit v1.3