diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 152 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 17 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 177 |
4 files changed, 136 insertions, 216 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 802c71e..c707f18 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -32,8 +32,6 @@ import Control.DeepSeq (NFData(..)) import Data.Bifunctor (first) import Data.Coerce import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy @@ -59,126 +57,104 @@ type family Rank sh where -- * Mixed lists to be used IxX and shaped and ranked lists and indexes type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: forall n sh {f}. f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +type ListX :: [Maybe Nat] -> Type -> Type +data ListX sh i where + ZX :: ListX '[] i + (::%) :: forall n sh {i}. i -> ListX sh i -> ListX (n : sh) i +deriving instance Eq i => Eq (ListX sh i) +deriving instance Ord i => Ord (ListX sh i) infixr 3 ::% #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListX sh f) +deriving instance Show i => Show (ListX sh i) #else -instance (forall n. Show (f n)) => Show (ListX sh f) where +instance Show i => Show (ListX sh i) where showsPrec _ = listxShow shows #endif -instance (forall n. NFData (f n)) => NFData (ListX sh f) where +instance NFData i => NFData (ListX sh i) where rnf ZX = () rnf (x ::% l) = rnf x `seq` rnf l -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +data UnconsListXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) listxUncons ZX = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , Just Refl <- listxEqType sh sh' - = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listxEqual sh sh' - = Just Refl -listxEqual _ _ = Nothing - -{-# INLINE listxFmap #-} -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs +instance Functor (ListX l) where + {-# INLINE fmap #-} + fmap _ ZX = ZX + fmap f (x ::% xs) = f x ::% fmap f xs -{-# INLINE listxFoldMap #-} -listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFoldMap _ ZX = mempty -listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs +instance Foldable (ListX l) where + {-# INLINE foldMap #-} + foldMap _ ZX = mempty + foldMap f (x ::% xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZX = z + foldr f z (x ::% xs) = f x (foldr f z xs) + toList = listxToList + null ZX = False + null _ = True -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFoldMap (\_ -> Sum 1) +listxLength :: ListX sh i -> Int +listxLength = length -listxRank :: ListX sh f -> SNat (Rank sh) +listxRank :: ListX sh i -> SNat (Rank sh) listxRank ZX = SNat listxRank (_ ::% l) | SNat <- listxRank l = SNat {-# INLINE listxShow #-} -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS listxShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListX sh' f -> ShowS + go :: String -> ListX sh' i -> ShowS go _ ZX = id go prefix (x ::% xs) = showString prefix . f x . go "," xs -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList :: StaticShX sh -> [i] -> ListX sh i listxFromList topssh topl = go topssh topl where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go :: StaticShX sh' -> [i] -> ListX sh' i go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is + go (_ :!% sh) (i : is) = i ::% go sh is go _ _ = error $ "listxFromList: Mismatched list length (type says " ++ show (ssxLength topssh) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listxToList #-} -listxToList :: ListX sh (Const i) -> [i] +listxToList :: ListX sh i -> [i] listxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListX sh (Const i) -> is + let go :: ListX sh i -> is go ZX = nil - go (Const i ::% is) = i `cons` go is + go (i ::% is) = i `cons` go is in go list) -listxHead :: ListX (mn ': sh) f -> f mn +listxHead :: ListX (mn ': sh) i -> i listxHead (i ::% _) = i listxTail :: ListX (n : sh) i -> ListX sh i listxTail (_ ::% sh) = sh -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend :: ListX sh i -> ListX sh' i -> ListX (sh ++ sh') i listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop :: forall i j sh sh'. ListX sh j -> ListX (sh ++ sh') i -> ListX sh' i listxDrop ZX long = long listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit :: forall i n sh. ListX (n : sh) i -> ListX (Init (n : sh)) i listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh listxInit (_ ::% ZX) = ZX -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast :: forall i n sh. ListX (n : sh) i -> i listxLast (_ ::% sh@(_ ::% _)) = listxLast sh listxLast (x ::% ZX) = x -listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) -listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest - {-# INLINE listxZipWith #-} -listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g - -> ListX sh h +listxZipWith :: (i -> j -> k) -> ListX sh i -> ListX sh j -> ListX sh k listxZipWith _ ZX ZX = ZX listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js @@ -188,8 +164,8 @@ listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js -- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, NFData) +newtype IxX sh i = IxX (ListX sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i pattern ZIX = IxX ZX @@ -198,8 +174,8 @@ pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) i)) + where i :.% IxX shl = IxX (i ::% shl) infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} @@ -212,23 +188,9 @@ type IIxX sh = IxX sh Int deriving instance Show i => Show (IxX sh i) #else instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = listxShow (shows . getConst) l + showsPrec _ (IxX l) = listxShow shows l #endif -instance Functor (IxX sh) where - {-# INLINE fmap #-} - fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where - {-# INLINE foldMap #-} - foldMap f (IxX l) = listxFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIX = z - foldr f z (x :.% xs) = f x (foldr f z xs) - toList = ixxToList - null ZIX = False - null _ = True - ixxLength :: IxX sh i -> Int ixxLength (IxX l) = listxLength l @@ -243,30 +205,30 @@ ixxZero' :: IShX sh -> IIxX sh ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh +{-# INLINEABLE ixxFromList #-} ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i ixxFromList = coerce (listxFromList @_ @i) -{-# INLINEABLE ixxToList #-} -ixxToList :: forall sh i. IxX sh i -> [i] -ixxToList = coerce (listxToList @_ @i) +ixxToList :: IxX sh i -> [i] +ixxToList = Foldable.toList ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) +ixxHead (IxX list) = listxHead list ixxTail :: IxX (n : sh) i -> IxX sh i ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) +ixxAppend = coerce (listxAppend @_ @i) ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) +ixxDrop = coerce (listxDrop @i @i) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) +ixxInit = coerce (listxInit @i) ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) +ixxLast = coerce (listxLast @i) ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i ixxCast ZKX ZIX = ZIX @@ -818,8 +780,8 @@ shxFlatten = go (SNat @1) -- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i +instance KnownShX sh => IsList (ListX sh i) where + type Item (ListX sh i) = i fromList = listxFromList (knownShX @sh) toList = listxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index c3d2075..ecdb06d 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -18,7 +18,6 @@ module Data.Array.Nested.Permutation where import Data.Coerce (coerce) -import Data.Functor.Const import Data.List (sort) import Data.Maybe (fromMaybe) import Data.Proxy @@ -236,33 +235,31 @@ shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) shxPermutePrefix = coerce (listhPermutePrefix @Int) -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i listxTakeLen PNil _ = ZX listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i listxDropLen PNil sh = sh listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i listxPermute PNil _ = ZX listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = listxIndex i sh ::% listxPermute is sh -listxIndex :: forall f i sh. SNat i -> ListX sh f -> f (Index i sh) +listxIndex :: forall j i sh. SNat i -> ListX sh j -> j listxIndex SZ (n ::% _) = n -listxIndex (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex i sh +listxIndex (SS i) (_ ::% sh) = listxIndex i sh listxIndex _ ZX = error "Index into empty shape" -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) +ixxPermutePrefix = coerce (listxPermutePrefix @i) -- * Operations on permutations diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 2f20e1a..2415e26 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -250,12 +250,12 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n +{-# INLINEABLE ixrFromList #-} ixrFromList :: forall n i. SNat n -> [i] -> IxR n i ixrFromList = coerce (listrFromList @_ @i) -{-# INLINEABLE ixrToList #-} -ixrToList :: forall n i. IxR n i -> [i] -ixrToList = coerce (listrToList @_ @i) +ixrToList :: IxR n i -> [i] +ixrToList = Foldable.toList ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index afd2227..8cd937c 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -30,11 +30,7 @@ import Control.DeepSeq (NFData(..)) import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality import GHC.Exts (build, withDict) import GHC.IsList (IsList) @@ -50,161 +46,141 @@ import Data.Array.Nested.Types -- * Shaped lists type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +type ListS :: [Nat] -> Type -> Type +data ListS sh i where + ZS :: ListS '[] i + (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i +deriving instance Eq i => Eq (ListS sh i) +deriving instance Ord i => Ord (ListS sh i) infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListS sh f) +deriving instance Show i => Show (ListS sh i) #else -instance (forall n. Show (f n)) => Show (ListS sh f) where +instance Show i => Show (ListS sh i) where showsPrec _ = listsShow shows #endif -instance (forall m. NFData (f m)) => NFData (ListS n f) where +instance NFData i => NFData (ListS n i) where rnf ZS = () rnf (x ::$ l) = rnf x `seq` rnf l -data UnconsListSRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +data UnconsListSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i +listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , Just Refl <- listsEqType sh sh' - = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listsEqual sh sh' - = Just Refl -listsEqual _ _ = Nothing - -{-# INLINE listsFmap #-} -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -{-# INLINE listsFoldMap #-} -listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFoldMap _ ZS = mempty -listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS listsShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListS sh' f -> ShowS + go :: String -> ListS sh' i -> ShowS go _ ZS = id go prefix (x ::$ xs) = showString prefix . f x . go "," xs -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFoldMap (\_ -> Sum 1) +instance Functor (ListS l) where + {-# INLINE fmap #-} + fmap _ ZS = ZS + fmap f (x ::$ xs) = f x ::$ fmap f xs -listsRank :: ListS sh f -> SNat (Rank sh) +instance Foldable (ListS l) where + {-# INLINE foldMap #-} + foldMap _ ZS = mempty + foldMap f (x ::$ xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZS = z + foldr f z (x ::$ xs) = f x (foldr f z xs) + toList = listsToList + null ZS = False + null _ = True + +listsLength :: ListS sh i -> Int +listsLength = length + +listsRank :: ListS sh i -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) -listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList :: ShS sh -> [i] -> ListS sh i listsFromList topsh topl = go topsh topl where - go :: ShS sh' -> [i] -> ListS sh' (Const i) + go :: ShS sh' -> [i] -> ListS sh' i go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go (_ :$$ sh) (i : is) = i ::$ go sh is go _ _ = error $ "listsFromList: Mismatched list length (type says " ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listsFromListS #-} -listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) +listsFromListS :: ListS sh i0 -> [i] -> ListS sh i listsFromListS topl0 topl = go topl0 topl where - go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) + go :: ListS sh i0 -> [i] -> ListS sh i go ZS [] = ZS - go (_ ::$ l0) (i : is) = Const i ::$ go l0 is + go (_ ::$ l0) (i : is) = i ::$ go l0 is go _ _ = error $ "listsFromListS: Mismatched list length (the model says " ++ show (listsLength topl0) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listsToList #-} -listsToList :: ListS sh (Const i) -> [i] +listsToList :: ListS sh i -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListS sh (Const i) -> is + let go :: ListS sh i -> is go ZS = nil - go (Const i ::$ is) = i `cons` go is + go (i ::$ is) = i `cons` go is in go list) -listsHead :: ListS (n : sh) f -> f n +listsHead :: ListS (n : sh) i -> i listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) f -> ListS sh f +listsTail :: ListS (n : sh) i -> ListS sh i listsTail (_ ::$ sh) = sh -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh listsInit (_ ::$ ZS) = ZS -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast :: ListS (n : sh) i -> i listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh listsLast (n ::$ ZS) = n -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js {-# INLINE listsZipWith #-} -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g - -> ListS sh h +listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k listsZipWith _ ZS ZS = ZS listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i listsDropLenPerm PNil sh = sh listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = case listsIndex i sh of item -> item ::$ listsPermute is sh -- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed -listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh) +listsIndex :: forall j i sh. SNat i -> ListS sh j -> j listsIndex SZ (n ::$ _) = n -listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex i sh +listsIndex (SS i) (_ ::$ sh) = listsIndex i sh listsIndex _ ZS = error "Index into empty shape" -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) -- * Shaped indices @@ -212,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe -- | An index into a shape-typed array. type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, NFData) +newtype IxS sh i = IxS (ListS sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -224,8 +200,8 @@ pattern (:.$) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i)) + where i :.$ IxS shl = IxS (i ::$ shl) infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} @@ -238,23 +214,9 @@ type IIxS sh = IxS sh Int deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + showsPrec _ (IxS l) = listsShow (\i -> shows i) l #endif -instance Functor (IxS sh) where - {-# INLINE fmap #-} - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - {-# INLINE foldMap #-} - foldMap f (IxS l) = listsFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIS = z - foldr f z (x :.$ xs) = f x (foldr f z xs) - toList = ixsToList - null ZIS = False - null _ = True - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -268,16 +230,15 @@ ixsFromList = coerce (listsFromList @_ @i) ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i ixsFromIxS = coerce (listsFromListS @_ @i0 @i) -{-# INLINEABLE ixsToList #-} -ixsToList :: forall sh i. IxS sh i -> [i] -ixsToList = coerce (listsToList @_ @i) +ixsToList :: IxS sh i -> [i] +ixsToList = Foldable.toList ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) +ixsHead (IxS list) = listsHead list ixsTail :: IxS (n : sh) i -> IxS sh i ixsTail (IxS list) = IxS (listsTail list) @@ -286,14 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) +ixsLast (IxS list) = listsLast list ixsCast :: IxS sh i -> IxS sh i ixsCast ZIS = ZIS ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsAppend = coerce (listsAppend @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS @@ -305,7 +266,7 @@ ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +ixsPermutePrefix = coerce (listsPermutePrefix @i) -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -519,8 +480,8 @@ shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict -- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i +instance KnownShS sh => IsList (ListS sh i) where + type Item (ListS sh i) = i fromList = listsFromList (knownShS @sh) toList = listsToList |
