aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs12
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs10
-rw-r--r--src/Data/Array/Nested/Shaped.hs7
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs11
4 files changed, 31 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 4dd350a..b0276c9 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -610,13 +610,19 @@ shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listhDrop @i @())
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSh = coerce (listhDrop @i @i)
+
+{-# INLINEABLE shxTakeIx #-}
+shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeIx _ (IxX ZX) _ = ZSX
+shxTakeIx proxy (IxX (_ ::% long)) short = case short of i :$% short' -> i :$% shxTakeIx proxy (IxX long) short'
+
+{-# INLINEABLE shxDropIx #-}
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx ZIX long = long
shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long'
-shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listhDrop @i @i)
-
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
shxInit = coerce (listhInit @i)
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index e44ab64..c5cdf6b 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -422,6 +422,16 @@ shrTail
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce (shxTail @_ @_ @i)
+{-# INLINEABLE shrTakeIx #-}
+shrTakeIx :: forall n n' i j. Proxy n' -> IxR n j -> ShR (n + n') i -> ShR n i
+shrTakeIx _ ZIR _ = ZSR
+shrTakeIx p (_ :.: idx) sh = case sh of n :$: sh' -> n :$: shrTakeIx p idx sh'
+
+{-# INLINEABLE shrDropIx #-}
+shrDropIx :: forall n n' i j. IxR n j -> ShR (n + n') i -> ShR n' i
+shrDropIx ZIR long = long
+shrDropIx (_ :.: short) long = case long of _ :$: long' -> shrDropIx short long'
+
shrInit :: forall n i. ShR (n + 1) i -> ShR n i
shrInit
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 08711b6..d57106a 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -56,16 +56,11 @@ ssize = shsSize . sshape
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
-{-# INLINEABLE shsTakeIx #-}
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh
-shsTakeIx _ _ ZIS = ZSS
-shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
-
{-# INLINEABLE sindexPartial #-}
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) idx (sshape sarr)) (Proxy @sh2))) arr)
(ixxFromIxS idx))
-- | __WARNING__: All values returned from the function must have equal shape.
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index f98c860..ef91c7b 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -30,6 +30,7 @@ import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Constraint, Type)
+import Data.Proxy
import Data.Type.Equality
import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
@@ -394,6 +395,16 @@ shsHead (ShS shx) = case shxHead shx of
shsTail :: forall n sh. ShS (n : sh) -> ShS sh
shsTail = coerce (shxTail @_ @_ @Int)
+{-# INLINEABLE shsTakeIx #-}
+shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh
+shsTakeIx _ ZIS _ = ZSS
+shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh'
+
+{-# INLINEABLE shsDropIx #-}
+shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh'
+shsDropIx ZIS long = long
+shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long'
+
shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
shsInit =
gcastWith (unsafeCoerceRefl