diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-09 18:35:35 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-09 18:35:35 +0200 | 
| commit | 36a89c7a26d73b0583ac765d29dadb1b918007f6 (patch) | |
| tree | c497bc390a252a58ac21b14260fb655690cf75b8 /src/Data/Array/Nested | |
| parent | ffa91484573a2c2be3f6ae2190c768e7a77e8b5c (diff) | |
Transpose functions in the API
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 15 | 
1 files changed, 15 insertions, 0 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index fa40921..d92c79c 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -315,6 +315,11 @@ mgenerate sh f      checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'      checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' +mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a +mtranspose perm = +  mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') +                            (X.transpose perm)) +  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is  -- represented on the type level as a 'Nat'. @@ -578,6 +583,11 @@ rsumOuter1 (Ranked arr)      . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)      $ arr +rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a +rtranspose perm (Ranked arr) +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mtranspose perm arr) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -640,3 +650,8 @@ ssumOuter1 (Shaped arr)      . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))      . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)      $ arr + +stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a +stranspose perm (Shaped arr) +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mtranspose perm arr) | 
