aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs177
1 files changed, 69 insertions, 108 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index afd2227..8cd937c 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -30,11 +30,7 @@ import Control.DeepSeq (NFData(..))
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
-import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
@@ -50,161 +46,141 @@ import Data.Array.Nested.Types
-- * Shaped lists
type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+type ListS :: [Nat] -> Type -> Type
+data ListS sh i where
+ ZS :: ListS '[] i
+ (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
+deriving instance Eq i => Eq (ListS sh i)
+deriving instance Ord i => Ord (ListS sh i)
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+deriving instance Show i => Show (ListS sh i)
#else
-instance (forall n. Show (f n)) => Show (ListS sh f) where
+instance Show i => Show (ListS sh i) where
showsPrec _ = listsShow shows
#endif
-instance (forall m. NFData (f m)) => NFData (ListS n f) where
+instance NFData i => NFData (ListS n i) where
rnf ZS = ()
rnf (x ::$ l) = rnf x `seq` rnf l
-data UnconsListSRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+data UnconsListSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i
+listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqType ZS ZS = Just Refl
-listsEqType (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listsEqType sh sh'
- = Just Refl
-listsEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqual ZS ZS = Just Refl
-listsEqual (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listsEqual sh sh'
- = Just Refl
-listsEqual _ _ = Nothing
-
-{-# INLINE listsFmap #-}
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-{-# INLINE listsFoldMap #-}
-listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFoldMap _ ZS = mempty
-listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListS sh' f -> ShowS
+ go :: String -> ListS sh' i -> ShowS
go _ ZS = id
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFoldMap (\_ -> Sum 1)
+instance Functor (ListS l) where
+ {-# INLINE fmap #-}
+ fmap _ ZS = ZS
+ fmap f (x ::$ xs) = f x ::$ fmap f xs
-listsRank :: ListS sh f -> SNat (Rank sh)
+instance Foldable (ListS l) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZS = mempty
+ foldMap f (x ::$ xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZS = z
+ foldr f z (x ::$ xs) = f x (foldr f z xs)
+ toList = listsToList
+ null ZS = False
+ null _ = True
+
+listsLength :: ListS sh i -> Int
+listsLength = length
+
+listsRank :: ListS sh i -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList :: ShS sh -> [i] -> ListS sh i
listsFromList topsh topl = go topsh topl
where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go :: ShS sh' -> [i] -> ListS sh' i
go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go (_ :$$ sh) (i : is) = i ::$ go sh is
go _ _ = error $ "listsFromList: Mismatched list length (type says "
++ show (shsLength topsh) ++ ", list has length "
++ show (length topl) ++ ")"
{-# INLINEABLE listsFromListS #-}
-listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+listsFromListS :: ListS sh i0 -> [i] -> ListS sh i
listsFromListS topl0 topl = go topl0 topl
where
- go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+ go :: ListS sh i0 -> [i] -> ListS sh i
go ZS [] = ZS
- go (_ ::$ l0) (i : is) = Const i ::$ go l0 is
+ go (_ ::$ l0) (i : is) = i ::$ go l0 is
go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
++ show (listsLength topl0) ++ ", list has length "
++ show (length topl) ++ ")"
{-# INLINEABLE listsToList #-}
-listsToList :: ListS sh (Const i) -> [i]
+listsToList :: ListS sh i -> [i]
listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListS sh (Const i) -> is
+ let go :: ListS sh i -> is
go ZS = nil
- go (Const i ::$ is) = i `cons` go is
+ go (i ::$ is) = i `cons` go is
in go list)
-listsHead :: ListS (n : sh) f -> f n
+listsHead :: ListS (n : sh) i -> i
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) f -> ListS sh f
+listsTail :: ListS (n : sh) i -> ListS sh i
listsTail (_ ::$ sh) = sh
-listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i
listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
listsInit (_ ::$ ZS) = ZS
-listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast :: ListS (n : sh) i -> i
listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
listsLast (n ::$ ZS) = n
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
+listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j)
listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js
+listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js
{-# INLINE listsZipWith #-}
-listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
- -> ListS sh h
+listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k
listsZipWith _ ZS ZS = ZS
listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i
listsDropLenPerm PNil sh = sh
listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
case listsIndex i sh of
item -> item ::$ listsPermute is sh
-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
-listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh)
+listsIndex :: forall j i sh. SNat i -> ListS sh j -> j
listsIndex SZ (n ::$ _) = n
-listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex i sh
+listsIndex (SS i) (_ ::$ sh) = listsIndex i sh
listsIndex _ ZS = error "Index into empty shape"
-listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-- * Shaped indices
@@ -212,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe
-- | An index into a shape-typed array.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, NFData)
+newtype IxS sh i = IxS (ListS sh i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
@@ -224,8 +200,8 @@ pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i))
+ where i :.$ IxS shl = IxS (i ::$ shl)
infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
@@ -238,23 +214,9 @@ type IIxS sh = IxS sh Int
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+ showsPrec _ (IxS l) = listsShow (\i -> shows i) l
#endif
-instance Functor (IxS sh) where
- {-# INLINE fmap #-}
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- {-# INLINE foldMap #-}
- foldMap f (IxS l) = listsFoldMap (f . getConst) l
- {-# INLINE foldr #-}
- foldr _ z ZIS = z
- foldr f z (x :.$ xs) = f x (foldr f z xs)
- toList = ixsToList
- null ZIS = False
- null _ = True
-
ixsLength :: IxS sh i -> Int
ixsLength (IxS l) = listsLength l
@@ -268,16 +230,15 @@ ixsFromList = coerce (listsFromList @_ @i)
ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
-{-# INLINEABLE ixsToList #-}
-ixsToList :: forall sh i. IxS sh i -> [i]
-ixsToList = coerce (listsToList @_ @i)
+ixsToList :: IxS sh i -> [i]
+ixsToList = Foldable.toList
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = getConst (listsHead list)
+ixsHead (IxS list) = listsHead list
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
@@ -286,14 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = getConst (listsLast list)
+ixsLast (IxS list) = listsLast list
ixsCast :: IxS sh i -> IxS sh i
ixsCast ZIS = ZIS
ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @(Const i))
+ixsAppend = coerce (listsAppend @_ @i)
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
@@ -305,7 +266,7 @@ ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+ixsPermutePrefix = coerce (listsPermutePrefix @i)
-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
@@ -519,8 +480,8 @@ shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
-- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
+instance KnownShS sh => IsList (ListS sh i) where
+ type Item (ListS sh i) = i
fromList = listsFromList (knownShS @sh)
toList = listsToList