aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs737
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs4
6 files changed, 9 insertions, 746 deletions
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
index c316161..5d6cee4 100644
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ b/src/Data/Array/Nested/Internal/Convert.hs
@@ -12,12 +12,12 @@ import Data.Proxy
import Data.Type.Equality
import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Internal.Shaped
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
index f894f78..b8baf96 100644
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ b/src/Data/Array/Nested/Internal/Lemmas.hs
@@ -11,9 +11,9 @@ import GHC.TypeLits
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index a2f9737..b76aa50 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -45,7 +45,7 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index daf0374..368e337 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -40,12 +40,12 @@ import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
deleted file mode 100644
index 97b9456..0000000
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ /dev/null
@@ -1,737 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RoleAnnotations #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Shape where
-
-import Control.DeepSeq (NFData(..))
-import Data.Array.Mixed.Types
-import Data.Array.Shape qualified as O
-import Data.Coerce (coerce)
-import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product qualified as Fun
-import Data.Kind (Constraint, Type)
-import Data.Monoid (Sum(..))
-import Data.Proxy
-import Data.Type.Equality
-import GHC.Exts (withDict)
-import GHC.Generics (Generic)
-import GHC.IsList (IsList)
-import GHC.IsList qualified as IsList
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqRank ZR ZR = Just Refl
-listrEqRank (_ ::: sh) (_ ::: sh')
- | Just Refl <- listrEqRank sh sh'
- = Just Refl
-listrEqRank _ _ = Nothing
-
--- | This compares the lists for value equality.
-listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqual ZR ZR = Just Refl
-listrEqual (i ::: sh) (j ::: sh')
- | Just Refl <- listrEqual sh sh'
- , i == j
- = Just Refl
-listrEqual _ _ = Nothing
-
-listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR n' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrLength :: ListR n i -> Int
-listrLength = length
-
-listrRank :: ListR n i -> SNat n
-listrRank ZR = SNat
-listrRank (_ ::: sh) = snatSucc (listrRank sh)
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrHead :: ListR (n + 1) i -> i
-listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
-
-listrTail :: ListR (n + 1) i -> ListR n i
-listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
-
-listrInit :: ListR (n + 1) i -> ListR n i
-listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
-listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
-
-listrLast :: ListR (n + 1) i -> i
-listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
-listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
-listrZip ZR ZR = ZR
-listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
-listrZip _ _ = error "listrZip: impossible pattern needlessly required"
-
-listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
-listrZipWith _ ZR ZR = ZR
-listrZipWith f (i ::: irest) (j ::: jrest) =
- f i j ::: listrZipWith f irest jrest
-listrZipWith _ _ _ =
- error "listrZipWith: impossible pattern needlessly required"
-
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
-infixr 3 :.:
-
-{-# COMPLETE ZIR, (:.:) #-}
-
-type IIxR n = IxR n Int
-
-instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
-
-instance NFData i => NFData (IxR sh i)
-
-ixrLength :: IxR sh i -> Int
-ixrLength (IxR l) = listrLength l
-
-ixrRank :: IxR n i -> SNat n
-ixrRank (IxR sh) = listrRank sh
-
-ixrZero :: SNat n -> IIxR n
-ixrZero SZ = ZIR
-ixrZero (SS n) = 0 :.: ixrZero n
-
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
-
-ixrHead :: IxR (n + 1) i -> i
-ixrHead (IxR list) = listrHead list
-
-ixrTail :: IxR (n + 1) i -> IxR n i
-ixrTail (IxR list) = IxR (listrTail list)
-
-ixrInit :: IxR (n + 1) i -> IxR n i
-ixrInit (IxR list) = IxR (listrInit list)
-
-ixrLast :: IxR (n + 1) i -> i
-ixrLast (IxR list) = listrLast list
-
-ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
-
-ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
-ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
-
-ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
-ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
-infixr 3 :$:
-
-{-# COMPLETE ZSR, (:$:) #-}
-
-type IShR n = ShR n Int
-
-instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
-
-instance NFData i => NFData (ShR sh i)
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerceRefl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerceRefl
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
-
--- | This compares the shapes for value equality.
-shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
-
-shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
-
--- | This function can also be used to conjure up a 'KnownNat' dictionary;
--- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
--- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
-
--- | The number of elements in an array described by this shape.
-shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
-
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
-
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
-
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
-
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
-
-shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
-
-shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
-
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (IxR n i) where
- type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
-
-
-type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
-infixr 3 ::$
-
-instance (forall n. Show (f n)) => Show (ListS sh f) where
- showsPrec _ = listsShow shows
-
-instance (forall m. NFData (f m)) => NFData (ListS n f) where
- rnf ZS = ()
- rnf (x ::$ l) = rnf x `seq` rnf l
-
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
-listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
-listsUncons ZS = Nothing
-
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqType ZS ZS = Just Refl
-listsEqType (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listsEqType sh sh'
- = Just Refl
-listsEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqual ZS ZS = Just Refl
-listsEqual (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listsEqual sh sh'
- = Just Refl
-listsEqual _ _ = Nothing
-
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFold _ ZS = mempty
-listsFold f (x ::$ xs) = f x <> listsFold f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
-listsShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListS sh' f -> ShowS
- go _ ZS = id
- go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-
-listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFold (\_ -> Sum 1)
-
-listsRank :: ListS sh f -> SNat (Rank sh)
-listsRank ZS = SNat
-listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-
-listsToList :: ListS sh (Const i) -> [i]
-listsToList ZS = []
-listsToList (Const i ::$ is) = i : listsToList is
-
-listsHead :: ListS (n : sh) f -> f n
-listsHead (i ::$ _) = i
-
-listsTail :: ListS (n : sh) f -> ListS sh f
-listsTail (_ ::$ sh) = sh
-
-listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
-listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
-listsInit (_ ::$ ZS) = ZS
-
-listsLast :: ListS (n : sh) f -> f (Last (n : sh))
-listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
-listsLast (n ::$ ZS) = n
-
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
-listsAppend ZS idx' = idx'
-listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-
-listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
-listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) =
- Fun.Pair i j ::$ listsZip is js
-
-listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
- -> ListS sh h
-listsZipWith _ ZS ZS = ZS
-listsZipWith f (i ::$ is) (j ::$ js) =
- f i j ::$ listsZipWith f is js
-
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
-listsTakeLenPerm PNil _ = ZS
-listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
-listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLenPerm PNil sh = sh
-listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
-listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
-listsPermute PNil _ = ZS
-listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
-
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
-
-listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
-listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-
-
--- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
-type role IxS nominal representational
-type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, Generic)
-
-pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
-
-pattern (:.$)
- :: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
-infixr 3 :.$
-
-{-# COMPLETE ZIS, (:.$) #-}
-
-type IIxS sh = IxS sh Int
-
-instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
-
-instance Functor (IxS sh) where
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- foldMap f (IxS l) = listsFold (f . getConst) l
-
-instance NFData i => NFData (IxS sh i)
-
-ixsLength :: IxS sh i -> Int
-ixsLength (IxS l) = listsLength l
-
-ixsRank :: IxS sh i -> SNat (Rank sh)
-ixsRank (IxS l) = listsRank l
-
-ixsZero :: ShS sh -> IIxS sh
-ixsZero ZSS = ZIS
-ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
-
-ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = getConst (listsHead list)
-
-ixsTail :: IxS (n : sh) i -> IxS sh i
-ixsTail (IxS list) = IxS (listsTail list)
-
-ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
-ixsInit (IxS list) = IxS (listsInit list)
-
-ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = getConst (listsLast list)
-
-ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @(Const i))
-
-ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
-ixsZip ZIS ZIS = ZIS
-ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
-
-ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
-ixsZipWith _ ZIS ZIS = ZIS
-ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
-
-ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
-
-
--- | The shape of a shape-typed array given as a list of 'SNat' values.
---
--- Note that because the shape of a shape-typed array is known statically, you
--- can also retrieve the array shape from a 'KnownShS' dictionary.
-type role ShS nominal
-type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord, Generic)
-
-pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
-
-pattern (:$$)
- :: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
-
-infixr 3 :$$
-
-{-# COMPLETE ZSS, (:$$) #-}
-
-instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
-
-instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
-
-instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
-
--- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
--- equal if and only if values are equal.)
-shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
-shsEqual = testEquality
-
-shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
-
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
-
-shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
-
-shsToList :: ShS sh -> [Int]
-shsToList ZSS = []
-shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-
-shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shCvtXS' (SUnknown _ :$% _) = error "impossible"
-
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
-
-shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
-
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
-
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
-
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
-
-shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
-
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
-
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
-
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
-
-shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
-
-type family Product sh where
- Product '[] = 1
- Product (n : ns) = n * Product ns
-
-shsProduct :: ShS sh -> SNat (Product sh)
-shsProduct ZSS = SNat
-shsProduct (n :$$ sh) = n `snatMul` shsProduct sh
-
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShS :: [Nat] -> Constraint
-class KnownShS sh where knownShS :: ShS sh
-instance KnownShS '[] where knownShS = ZSS
-instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
-
-withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
-
-shsKnownShS :: ShS sh -> Dict KnownShS sh
-shsKnownShS ZSS = Dict
-shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict
-
-shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
-shsOrthotopeShape ZSS = Dict
-shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
-
-
--- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = listsToList
-
--- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShS sh => IsList (IxS sh i) where
- type Item (IxS sh i) = i
- fromList = IxS . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length and values are checked at runtime.
-instance KnownShS sh => IsList (ShS sh) where
- type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = shsToList
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 372439f..86dcee2 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -40,13 +40,13 @@ import GHC.TypeLits
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray)
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Strided.Arith