aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Mixed.hs18
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs8
2 files changed, 15 insertions, 11 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index ffbc993..deb32b2 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -589,7 +589,7 @@ instance Elt a => Elt (Mixed sh' a) where
{-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+ = shxTakeSh (Proxy @sh') sh (mshape arr)
{-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
@@ -617,10 +617,10 @@ instance Elt a => Elt (Mixed sh' a) where
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -635,10 +635,10 @@ instance Elt a => Elt (Mixed sh' a) where
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -654,10 +654,10 @@ instance Elt a => Elt (Mixed sh' a) where
-> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -690,7 +690,7 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
+ in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
@@ -948,7 +948,7 @@ munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
-mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index b3f0c2f..abcf3f8 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -604,12 +604,16 @@ shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh'
shxTakeSSX _ ZKX _ = ZSX
shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSh _ ZSX _ = ZSX
+shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh
+
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropIx (IxX ZX) long = long
-shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long'
+shxDropIx ZIX long = long
+shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long'
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh = coerce (listhDrop @i @i)