aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs11
1 files changed, 11 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index f98c860..ef91c7b 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -30,6 +30,7 @@ import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Constraint, Type)
+import Data.Proxy
import Data.Type.Equality
import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
@@ -394,6 +395,16 @@ shsHead (ShS shx) = case shxHead shx of
shsTail :: forall n sh. ShS (n : sh) -> ShS sh
shsTail = coerce (shxTail @_ @_ @Int)
+{-# INLINEABLE shsTakeIx #-}
+shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh
+shsTakeIx _ ZIS _ = ZSS
+shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh'
+
+{-# INLINEABLE shsDropIx #-}
+shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh'
+shsDropIx ZIS long = long
+shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long'
+
shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
shsInit =
gcastWith (unsafeCoerceRefl