aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs58
1 files changed, 30 insertions, 28 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 4a3ed8d..e635f03 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -84,13 +84,16 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+
+ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
+ssumAllPrimP (Shaped arr) = msumAllPrimP arr
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -191,36 +194,35 @@ szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
sunzip = coerce munzip
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))
- (shxFromShS sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
+srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b))
+srerankPrimP sh2 f (Shaped (M_Shaped arr))
+ = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
--- | See the caveats at 'Data.Array.XArray.rerank'.
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+-- | See the caveats at 'mrerankPrim'.
+srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b)
+srerankPrim sh2 f (Shaped (M_Shaped arr)) =
+ Shaped (M_Shaped (mrerankPrim (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh')
= Shaped (mreplicate (shxFromShS sh) arr)
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x)
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =