From fe034ff95a1f299ed140f37e416b5562cd423457 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 17 Dec 2025 19:06:38 +0100 Subject: Make List?, except ListH, less general --- src/Data/Array/Nested/Shaped/Shape.hs | 177 +++++++++++++--------------------- 1 file changed, 69 insertions(+), 108 deletions(-) (limited to 'src/Data/Array/Nested/Shaped') diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index afd2227..8cd937c 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -30,11 +30,7 @@ import Control.DeepSeq (NFData(..)) import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality import GHC.Exts (build, withDict) import GHC.IsList (IsList) @@ -50,161 +46,141 @@ import Data.Array.Nested.Types -- * Shaped lists type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +type ListS :: [Nat] -> Type -> Type +data ListS sh i where + ZS :: ListS '[] i + (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i +deriving instance Eq i => Eq (ListS sh i) +deriving instance Ord i => Ord (ListS sh i) infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListS sh f) +deriving instance Show i => Show (ListS sh i) #else -instance (forall n. Show (f n)) => Show (ListS sh f) where +instance Show i => Show (ListS sh i) where showsPrec _ = listsShow shows #endif -instance (forall m. NFData (f m)) => NFData (ListS n f) where +instance NFData i => NFData (ListS n i) where rnf ZS = () rnf (x ::$ l) = rnf x `seq` rnf l -data UnconsListSRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +data UnconsListSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i +listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , Just Refl <- listsEqType sh sh' - = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listsEqual sh sh' - = Just Refl -listsEqual _ _ = Nothing - -{-# INLINE listsFmap #-} -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -{-# INLINE listsFoldMap #-} -listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFoldMap _ ZS = mempty -listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS listsShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListS sh' f -> ShowS + go :: String -> ListS sh' i -> ShowS go _ ZS = id go prefix (x ::$ xs) = showString prefix . f x . go "," xs -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFoldMap (\_ -> Sum 1) +instance Functor (ListS l) where + {-# INLINE fmap #-} + fmap _ ZS = ZS + fmap f (x ::$ xs) = f x ::$ fmap f xs + +instance Foldable (ListS l) where + {-# INLINE foldMap #-} + foldMap _ ZS = mempty + foldMap f (x ::$ xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZS = z + foldr f z (x ::$ xs) = f x (foldr f z xs) + toList = listsToList + null ZS = False + null _ = True + +listsLength :: ListS sh i -> Int +listsLength = length -listsRank :: ListS sh f -> SNat (Rank sh) +listsRank :: ListS sh i -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) -listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList :: ShS sh -> [i] -> ListS sh i listsFromList topsh topl = go topsh topl where - go :: ShS sh' -> [i] -> ListS sh' (Const i) + go :: ShS sh' -> [i] -> ListS sh' i go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go (_ :$$ sh) (i : is) = i ::$ go sh is go _ _ = error $ "listsFromList: Mismatched list length (type says " ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listsFromListS #-} -listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) +listsFromListS :: ListS sh i0 -> [i] -> ListS sh i listsFromListS topl0 topl = go topl0 topl where - go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) + go :: ListS sh i0 -> [i] -> ListS sh i go ZS [] = ZS - go (_ ::$ l0) (i : is) = Const i ::$ go l0 is + go (_ ::$ l0) (i : is) = i ::$ go l0 is go _ _ = error $ "listsFromListS: Mismatched list length (the model says " ++ show (listsLength topl0) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listsToList #-} -listsToList :: ListS sh (Const i) -> [i] +listsToList :: ListS sh i -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListS sh (Const i) -> is + let go :: ListS sh i -> is go ZS = nil - go (Const i ::$ is) = i `cons` go is + go (i ::$ is) = i `cons` go is in go list) -listsHead :: ListS (n : sh) f -> f n +listsHead :: ListS (n : sh) i -> i listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) f -> ListS sh f +listsTail :: ListS (n : sh) i -> ListS sh i listsTail (_ ::$ sh) = sh -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh listsInit (_ ::$ ZS) = ZS -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast :: ListS (n : sh) i -> i listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh listsLast (n ::$ ZS) = n -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js {-# INLINE listsZipWith #-} -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g - -> ListS sh h +listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k listsZipWith _ ZS ZS = ZS listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i listsDropLenPerm PNil sh = sh listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = case listsIndex i sh of item -> item ::$ listsPermute is sh -- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed -listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh) +listsIndex :: forall j i sh. SNat i -> ListS sh j -> j listsIndex SZ (n ::$ _) = n -listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex i sh +listsIndex (SS i) (_ ::$ sh) = listsIndex i sh listsIndex _ ZS = error "Index into empty shape" -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) -- * Shaped indices @@ -212,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe -- | An index into a shape-typed array. type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, NFData) +newtype IxS sh i = IxS (ListS sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -224,8 +200,8 @@ pattern (:.$) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i)) + where i :.$ IxS shl = IxS (i ::$ shl) infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} @@ -238,23 +214,9 @@ type IIxS sh = IxS sh Int deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + showsPrec _ (IxS l) = listsShow (\i -> shows i) l #endif -instance Functor (IxS sh) where - {-# INLINE fmap #-} - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - {-# INLINE foldMap #-} - foldMap f (IxS l) = listsFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIS = z - foldr f z (x :.$ xs) = f x (foldr f z xs) - toList = ixsToList - null ZIS = False - null _ = True - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -268,16 +230,15 @@ ixsFromList = coerce (listsFromList @_ @i) ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i ixsFromIxS = coerce (listsFromListS @_ @i0 @i) -{-# INLINEABLE ixsToList #-} -ixsToList :: forall sh i. IxS sh i -> [i] -ixsToList = coerce (listsToList @_ @i) +ixsToList :: IxS sh i -> [i] +ixsToList = Foldable.toList ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) +ixsHead (IxS list) = listsHead list ixsTail :: IxS (n : sh) i -> IxS sh i ixsTail (IxS list) = IxS (listsTail list) @@ -286,14 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) +ixsLast (IxS list) = listsLast list ixsCast :: IxS sh i -> IxS sh i ixsCast ZIS = ZIS ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsAppend = coerce (listsAppend @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS @@ -305,7 +266,7 @@ ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +ixsPermutePrefix = coerce (listsPermutePrefix @i) -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -519,8 +480,8 @@ shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict -- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i +instance KnownShS sh => IsList (ListS sh i) where + type Item (ListS sh i) = i fromList = listsFromList (knownShS @sh) toList = listsToList -- cgit v1.2.3-70-g09d2