aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs202
1 files changed, 66 insertions, 136 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index d1a72ea..392ceac 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -42,137 +42,38 @@ import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
--- * Shaped lists
-
-type role ListS nominal representational
-type ListS :: [Nat] -> Type -> Type
-newtype ListS sh i = ListS (ListX (MapJust sh) i)
- deriving (Eq, Ord, NFData, Functor, Foldable)
-
-pattern ZS :: forall sh i. () => sh ~ '[] => ListS sh i
-pattern ZS <- ListS (matchZX -> Just Refl)
- where ZS = ListS ZX
-
-matchZX :: forall sh i. ListX (MapJust sh) i -> Maybe (sh :~: '[])
-matchZX ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
-matchZX _ = Nothing
-
-pattern (::$)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> ListS sh i -> ListS sh1 i
-pattern i ::$ l <- (listsUncons -> Just (UnconsListSRes i l))
- where i ::$ ListS l = ListS (i ::% l)
-infixr 3 ::$
-
-data UnconsListSRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListSRes i (ListS sh i)
-listsUncons :: forall sh1 i. ListS sh1 i -> Maybe (UnconsListSRes i sh1)
-listsUncons (ListS (i ::% l)) | Refl <- lemMapJustHead (Proxy @sh1)
- , Refl <- lemMapJustCons @sh1 Refl =
- Just (UnconsListSRes i (ListS l))
-listsUncons (ListS _) = Nothing
-
-{-# COMPLETE ZS, (::$) #-}
-
-#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance Show i => Show (ListS sh i)
-#else
-instance Show i => Show (ListS sh i) where
- showsPrec _ = listsShow shows
-#endif
-
-listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS
-listsShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListS sh' i -> ShowS
- go _ ZS = id
- go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-
-listsRank :: ListS sh i -> SNat (Rank sh)
-listsRank ZS = SNat
-listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-
-{-# INLINE listsFromList #-}
-listsFromList :: ShS sh -> [i] -> ListS sh i
-listsFromList sh l = assert (shsLength sh == length l)
- $ ListS $ IsList.fromList l
-
-{-# INLINE listsFromListS #-}
-listsFromListS :: ListS sh i0 -> [i] -> ListS sh i
-listsFromListS sh l = assert (length sh == length l)
- $ ListS $ IsList.fromList l
-
-listsHead :: ListS (n : sh) i -> i
-listsHead (i ::$ _) = i
-
-listsTail :: ListS (n : sh) i -> ListS sh i
-listsTail (_ ::$ sh) = sh
-
-listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i
-listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
-listsInit (_ ::$ ZS) = ZS
-
-listsLast :: ListS (n : sh) i -> i
-listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
-listsLast (n ::$ ZS) = n
-
-listsAppend :: forall sh sh' i. ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i
-listsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $
- coerce (listxAppend @_ @_ @i)
-
-listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j)
-listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js
-
-{-# INLINE listsZipWith #-}
-listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k
-listsZipWith _ ZS ZS = ZS
-listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
-
-listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i
-listsTakeLenPerm PNil _ = ZS
-listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
-listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i
-listsDropLenPerm PNil sh = sh
-listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
-listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i
-listsPermute PNil _ = ZS
-listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex i sh of
- item -> item ::$ listsPermute is sh
-
-listsIndex :: forall j i sh. SNat i -> ListS sh j -> j
-listsIndex SZ (n ::$ _) = n
-listsIndex (SS i) (_ ::$ sh) = listsIndex i sh
-listsIndex _ ZS = error "Index into empty shape"
-
-listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i
-listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-
-- * Shaped indices
-- | An index into a shape-typed array.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh i)
+newtype IxS sh i = IxS (IxX (MapJust sh) i)
deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
+pattern ZIS <- IxS (matchZIX -> Just Refl)
+ where ZIS = IxS ZIX
+
+matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZIX _ = Nothing
pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
-pattern i :.$ l <- IxS (i ::$ (IxS -> l))
- where i :.$ IxS l = IxS (i ::$ l)
+pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l))
+ where i :.$ IxS l = IxS (i :.% l)
infixr 3 :.$
+data UnconsIxSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i)
+ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
+ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1)
+ , Refl <- lemMapJustCons @sh1 Refl =
+ Just (UnconsIxSRes i (IxS l))
+ixsUncons (IxS _) = Nothing
+
{-# COMPLETE ZIS, (:.$) #-}
-- For convenience, this contains regular 'Int's instead of bounded integers
@@ -183,41 +84,55 @@ type IIxS sh = IxS sh Int
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow shows l
+ showsPrec _ l = ixsShow shows l
#endif
+ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS
+ixsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> IxS sh' i -> ShowS
+ go _ ZIS = id
+ go prefix (x :.$ xs) = showString prefix . f x . go "," xs
+
ixsRank :: IxS sh i -> SNat (Rank sh)
-ixsRank (IxS l) = listsRank l
+ixsRank ZIS = SNat
+ixsRank (_ :.$ sh) = snatSucc (ixsRank sh)
-ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
-ixsFromList = coerce (listsFromList @_ @i)
+{-# INLINE ixsFromList #-}
+ixsFromList :: ShS sh -> [i] -> IxS sh i
+ixsFromList sh l = assert (shsLength sh == length l)
+ $ IxS $ IsList.fromList l
-{-# INLINEABLE ixsFromIxS #-}
-ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
-ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
+{-# INLINE ixsFromIxS #-}
+ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS sh l = assert (length sh == length l)
+ $ IxS $ IsList.fromList l
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = listsHead list
+ixsHead (i :.$ _) = i
ixsTail :: IxS (n : sh) i -> IxS sh i
-ixsTail (IxS list) = IxS (listsTail list)
+ixsTail (_ :.$ sh) = sh
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
-ixsInit (IxS list) = IxS (listsInit list)
+ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh
+ixsInit (_ :.$ ZIS) = ZIS
ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = listsLast list
+ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh
+ixsLast (n :.$ ZIS) = n
ixsCast :: IxS sh i -> IxS sh i
ixsCast ZIS = ZIS
ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @_ @i)
+ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $
+ coerce (listxAppend @_ @_ @i)
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
@@ -228,8 +143,29 @@ ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
+ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i
+ixsTakeLenPerm PNil _ = ZIS
+ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh
+ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i
+ixsDropLenPerm PNil sh = sh
+ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh
+ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i
+ixsPermute PNil _ = ZIS
+ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) =
+ case ixsIndex i sh of
+ item -> item :.$ ixsPermute is sh
+
+ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j
+ixsIndex SZ (n :.$ _) = n
+ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh
+ixsIndex _ ZIS = error "Index into empty shape"
+
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @i)
+ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh)
-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
@@ -450,16 +386,10 @@ shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
--- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh i) where
- type Item (ListS sh i) = i
- fromList = listsFromList (knownShS @sh)
- toList = Foldable.toList
-
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
type Item (IxS sh i) = i
- fromList = IxS . IsList.fromList
+ fromList = ixsFromList (knownShS @sh)
toList = Foldable.toList
-- | Untyped: length and values are checked at runtime.