diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-09 18:35:35 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-09 18:35:35 +0200 |
commit | 36a89c7a26d73b0583ac765d29dadb1b918007f6 (patch) | |
tree | c497bc390a252a58ac21b14260fb655690cf75b8 | |
parent | ffa91484573a2c2be3f6ae2190c768e7a77e8b5c (diff) |
Transpose functions in the API
-rw-r--r-- | src/Data/Array/Mixed.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 15 | ||||
-rw-r--r-- | test/Main.hs | 3 |
4 files changed, 27 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2ad1d26..12c247f 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -251,13 +251,20 @@ rerank ssh ssh1 ssh2 f (XArray arr) , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) (\a -> unXArray (f (XArray a))) arr) where unXArray (XArray a) = a +rerankTop :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + rerank2 :: forall sh sh1 sh2 a b c. (U.Unbox a, U.Unbox b, U.Unbox c) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 983a636..aa7c7f9 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,6 +4,7 @@ module Data.Array.Nested ( Ranked, IxR(..), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, + rtranspose, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -12,6 +13,7 @@ module Data.Array.Nested ( IxS(..), KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + stranspose, -- ** Lifting orthotope operations to 'Shaped' arrays slift, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index fa40921..d92c79c 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -315,6 +315,11 @@ mgenerate sh f checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' +mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a +mtranspose perm = + mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') + (X.transpose perm)) + -- | A rank-typed array: the number of dimensions of the array (its /rank/) is -- represented on the type level as a 'Nat'. @@ -578,6 +583,11 @@ rsumOuter1 (Ranked arr) . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) $ arr +rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a +rtranspose perm (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mtranspose perm arr) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -640,3 +650,8 @@ ssumOuter1 (Shaped arr) . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) $ arr + +stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a +stranspose perm (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mtranspose perm arr) diff --git a/test/Main.hs b/test/Main.hs index 156e0a5..8257ff0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -9,7 +9,7 @@ import Data.Array.Nested arr :: Ranked N2 (Shaped [N2, N3] (Double, Int)) arr = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) -> sgenerate @[N2, N3] (2 ::$ 3 ::$ IZS) $ \(k ::$ l ::$ IZS) -> - let s = i + j + k + l + let s = 24*i + 6*j + 3*k + l in (fromIntegral s, s) foo :: (Double, Int) @@ -19,3 +19,4 @@ main :: IO () main = do print arr print foo + print (rtranspose [1,0] arr) |