aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs1
-rw-r--r--src/Data/Array/Nested/Convert.hs3
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs282
3 files changed, 104 insertions, 182 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index ec81843..9922644 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -3,7 +3,6 @@
module Data.Array.Nested (
-- * Ranked arrays
Ranked(Ranked),
- ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim,
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 2595c64..c6f23ae 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -16,7 +16,7 @@ module Data.Array.Nested.Convert (
-- * Shape\/index\/list casting functions
-- ** To ranked
ixrFromIxS, ixrFromIxS', ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX,
- listrCast, ixrCast, shrCast,
+ ixrCast, shrCast,
-- ** To shaped
ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
ixsCast,
@@ -86,7 +86,6 @@ shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx
shrFromShX :: IShX (Replicate n Nothing) -> IShR n
shrFromShX = coerce
--- listrCast re-exported
-- ixrCast re-exported
-- shrCast re-exported
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index a352eb3..5e84a2d 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -44,170 +44,40 @@ import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
--- * Ranked lists
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-newtype ListR n i = ListR (ListX (Replicate n Nothing) i)
- deriving (Eq, Ord, NFData, Functor, Foldable)
-
-pattern ZR :: forall n i. () => n ~ 0 => ListR n i
-pattern ZR <- ListR (matchZX @n -> Just Refl)
- where ZR = ListR ZX
-
-matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0)
-matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
-matchZX _ = Nothing
-
-pattern (:::)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ListR n i -> ListR n1 i
-pattern i ::: l <- (listrUncons -> Just (UnconsListRRes i l))
- where i ::: ListR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% l)
-infixr 3 :::
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes i (ListR n i)
-listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (ListR ((::%) @n @sh i l))
- | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl
- , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl
- , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl =
- Just (UnconsListRRes i (ListR @(Rank sh) l))
-listrUncons (ListR _) = Nothing
-
-{-# COMPLETE ZR, (:::) #-}
-
-#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance Show i => Show (ListR n i)
-#else
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-#endif
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqRank ZR ZR = Just Refl
-listrEqRank (_ ::: sh) (_ ::: sh')
- | Just Refl <- listrEqRank sh sh'
- = Just Refl
-listrEqRank _ _ = Nothing
-
--- | This compares the lists for value equality.
-listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqual ZR ZR = Just Refl
-listrEqual (i ::: sh) (j ::: sh')
- | Just Refl <- listrEqual sh sh'
- , i == j
- = Just Refl
-listrEqual _ _ = Nothing
-
-{-# INLINE listrShow #-}
-listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR n' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrRank :: ListR n i -> SNat n
-listrRank ZR = SNat
-listrRank (_ ::: sh) = snatSucc (listrRank sh)
-
--- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked)
-listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $
- coerce (listxAppend @_ @_ @i)
-
-{-# INLINE listrFromList #-}
-listrFromList :: SNat n -> [i] -> ListR n i
-listrFromList topsn topl = assert (fromSNat' topsn == length topl)
- $ ListR $ IsList.fromList topl
-
-listrHead :: ListR (n + 1) i -> i
-listrHead (i ::: _) = i
-
-listrTail :: ListR (n + 1) i -> ListR n i
-listrTail (_ ::: sh) = sh
-
-listrInit :: ListR (n + 1) i -> ListR n i
-listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
-listrInit (_ ::: ZR) = ZR
-
-listrLast :: ListR (n + 1) i -> i
-listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
-listrLast (n ::: ZR) = n
-
--- | Performs a runtime check that the lengths are identical.
-listrCast :: SNat n' -> ListR n i -> ListR n' i
-listrCast = listrCastWithName "listrCast"
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
-listrZip ZR ZR = ZR
-listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
-listrZip _ _ = error "listrZip: impossible pattern needlessly required"
-
-{-# INLINE listrZipWith #-}
-listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
-listrZipWith _ ZR ZR = ZR
-listrZipWith f (i ::: irest) (j ::: jrest) =
- f i j ::: listrZipWith f irest jrest
-listrZipWith _ _ _ =
- error "listrZipWith: impossible pattern needlessly required"
-
-listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
-listrSplitAt SZ sh = (ZR, sh)
-listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
-listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
-listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
- case listrRank sh of { shlen@SNat ->
- let sperm = listrFromList permlen perm in
- case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- }
- where
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
-
-- * Ranked indices
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
+newtype IxR n i = IxR (IxX (Replicate n Nothing) i)
deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
+pattern ZIR <- IxR (matchZIX @n -> Just Refl)
+ where ZIR = IxR ZIX
+
+matchZIX :: forall n i. IxX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZIX ZIX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZIX _ = Nothing
pattern (:.:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> IxR n i -> IxR n1 i
-pattern i :.: l <- IxR (i ::: (IxR -> l))
- where i :.: IxR l = IxR (i ::: l)
+pattern i :.: l <- (ixrUncons -> Just (UnconsIxRRes i l))
+ where i :.: IxR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = IxR (i :.% l)
infixr 3 :.:
+data UnconsIxRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsIxRRes i (IxR n i)
+ixrUncons :: forall n1 i. IxR n1 i -> Maybe (UnconsIxRRes i n1)
+ixrUncons (IxR ((:.%) @n @sh i l))
+ | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl =
+ Just (UnconsIxRRes i (IxR @(Rank sh) l))
+ixrUncons (IxR _) = Nothing
+
{-# COMPLETE ZIR, (:.:) #-}
-- For convenience, this contains regular 'Int's instead of bounded integers
@@ -218,48 +88,116 @@ type IIxR n = IxR n Int
deriving instance Show i => Show (IxR n i)
#else
instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
+ showsPrec _ = ixrShow shows
#endif
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+ixrEqRank :: IxR n i -> IxR n' i -> Maybe (n :~: n')
+ixrEqRank ZIR ZIR = Just Refl
+ixrEqRank (_ :.: sh) (_ :.: sh')
+ | Just Refl <- ixrEqRank sh sh'
+ = Just Refl
+ixrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+ixrEqual :: Eq i => IxR n i -> IxR n' i -> Maybe (n :~: n')
+ixrEqual ZIR ZIR = Just Refl
+ixrEqual (i :.: sh) (j :.: sh')
+ | Just Refl <- ixrEqual sh sh'
+ , i == j
+ = Just Refl
+ixrEqual _ _ = Nothing
+
+{-# INLINE ixrShow #-}
+ixrShow :: forall n i. (i -> ShowS) -> IxR n i -> ShowS
+ixrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> IxR n' i -> ShowS
+ go _ ZIR = id
+ go prefix (x :.: xs) = showString prefix . f x . go "," xs
+
ixrRank :: IxR n i -> SNat n
-ixrRank (IxR sh) = listrRank sh
+ixrRank ZIR = SNat
+ixrRank (_ :.: sh) = snatSucc (ixrRank sh)
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-{-# INLINEABLE ixrFromList #-}
-ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
-ixrFromList = coerce (listrFromList @_ @i)
+{-# INLINE ixrFromList #-}
+ixrFromList :: SNat n -> [i] -> IxR n i
+ixrFromList topsn topl = assert (fromSNat' topsn == length topl)
+ $ IxR $ IsList.fromList topl
ixrHead :: IxR (n + 1) i -> i
-ixrHead (IxR list) = listrHead list
+ixrHead (i :.: _) = i
ixrTail :: IxR (n + 1) i -> IxR n i
-ixrTail (IxR list) = IxR (listrTail list)
+ixrTail (_ :.: sh) = sh
ixrInit :: IxR (n + 1) i -> IxR n i
-ixrInit (IxR list) = IxR (listrInit list)
+ixrInit (n :.: sh@(_ :.: _)) = n :.: ixrInit sh
+ixrInit (_ :.: ZIR) = ZIR
ixrLast :: IxR (n + 1) i -> i
-ixrLast (IxR list) = listrLast list
+ixrLast (_ :.: sh@(_ :.: _)) = ixrLast sh
+ixrLast (n :.: ZIR) = n
-- | Performs a runtime check that the lengths are identical.
ixrCast :: SNat n' -> IxR n i -> IxR n' i
-ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+ixrCast SZ ZIR = ZIR
+ixrCast (SS n) (i :.: l) = i :.: ixrCast n l
+ixrCast _ _ = error "ixrCast: ranks don't match"
+-- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked)
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @n @m @i)
+ixrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $
+ coerce (listxAppend @_ @_ @i)
+
+ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i
+ixrIndex SZ (x :.: _) = x
+ixrIndex (SS i) (_ :.: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = ixrIndex i xs
+ixrIndex _ ZIR = error "k + 1 <= 0"
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
-ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+ixrZip ZIR ZIR = ZIR
+ixrZip (i :.: irest) (j :.: jrest) = (i, j) :.: ixrZip irest jrest
+ixrZip _ _ = error "ixrZip: impossible pattern needlessly required"
{-# INLINE ixrZipWith #-}
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
-ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+ixrZipWith _ ZIR ZIR = ZIR
+ixrZipWith f (i :.: irest) (j :.: jrest) =
+ f i j :.: ixrZipWith f irest jrest
+ixrZipWith _ _ _ =
+ error "ixrZipWith: impossible pattern needlessly required"
+
+ixrSplitAt :: m <= n' => SNat m -> IxR n' i -> (IxR m i, IxR (n' - m) i)
+ixrSplitAt SZ sh = (ZIR, sh)
+ixrSplitAt (SS m) (n :.: sh) = (\(pre, post) -> (n :.: pre, post)) (ixrSplitAt m sh)
+ixrSplitAt SS{} ZIR = error "m' + 1 <= 0"
ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @n @i)
+ixrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case ixrRank sh of { shlen@SNat ->
+ let sperm = ixrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
+ where
+ applyPermRFull :: SNat m -> IxR k Int -> IxR m i -> IxR k i
+ applyPermRFull _ ZIR _ = ZIR
+ applyPermRFull sm@SNat (i :.: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> ixrIndex si l :.: applyPermRFull sm perm l
+ EQI -> ixrIndex si l :.: applyPermRFull sm perm l
+ GTI -> error "ixrPermutePrefix: Index in permutation out of range"
-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
@@ -451,7 +389,7 @@ shrIndex k (ShR sh) = case shxIndex @i k sh of
SUnknown i -> i
SKnown{} -> error "shrIndex: impossible SKnown"
--- Copy-pasted from listrPermutePrefix, probably unavoidably.
+-- Copy-pasted from ixrPermutePrefix, probably unavoidably.
shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
shrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
@@ -475,15 +413,9 @@ shrPermutePrefix = \perm sh ->
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList = listrFromList (SNat @n)
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
+ fromList = ixrFromList (SNat @n)
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
@@ -491,11 +423,3 @@ instance KnownNat n => IsList (IShR n) where
type Item (IShR n) = Int
fromList = shrFromList (SNat @n)
toList = shrToList
-
-
--- * Internal helper functions
-
-listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
-listrCastWithName _ SZ ZR = ZR
-listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
-listrCastWithName name _ _ = error $ name ++ ": ranks don't match"