diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 18 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 10 |
3 files changed, 20 insertions, 16 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index ffbc993..deb32b2 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -589,7 +589,7 @@ instance Elt a => Elt (Mixed sh' a) where {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + = shxTakeSh (Proxy @sh') sh (mshape arr) {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a @@ -617,10 +617,10 @@ instance Elt a => Elt (Mixed sh' a) where -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -635,10 +635,10 @@ instance Elt a => Elt (Mixed sh' a) where -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -654,10 +654,10 @@ instance Elt a => Elt (Mixed sh' a) where -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -690,7 +690,7 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr @@ -948,7 +948,7 @@ munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr +mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index b3f0c2f..abcf3f8 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -604,12 +604,16 @@ shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh' shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh +shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSh _ ZSX _ = ZSX +shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh + shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx (IxX ZX) long = long -shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long' +shxDropIx ZIX long = long +shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i shxDropSh = coerce (listhDrop @i @i) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 3f23478..42aed6e 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -88,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 then XArray arr else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" @@ -185,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -212,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -279,7 +279,7 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX ssh'F = ssxFromShX sh'F @@ -299,7 +299,7 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX in sumInner ssh' (ssxFromShX shF) $ transpose2 (ssxFromShX shF) ssh' $ |
