summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fa40921..d92c79c 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -315,6 +315,11 @@ mgenerate sh f
checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
+mtranspose perm =
+ mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
+ (X.transpose perm))
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'Nat'.
@@ -578,6 +583,11 @@ rsumOuter1 (Ranked arr)
. coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
$ arr
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mtranspose perm arr)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -640,3 +650,8 @@ ssumOuter1 (Shaped arr)
. X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr
+
+stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
+stranspose perm (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mtranspose perm arr)