aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-04-09 18:35:35 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-04-09 18:35:35 +0200
commit36a89c7a26d73b0583ac765d29dadb1b918007f6 (patch)
treec497bc390a252a58ac21b14260fb655690cf75b8
parentffa91484573a2c2be3f6ae2190c768e7a77e8b5c (diff)
Transpose functions in the API
-rw-r--r--src/Data/Array/Mixed.hs9
-rw-r--r--src/Data/Array/Nested.hs2
-rw-r--r--src/Data/Array/Nested/Internal.hs15
-rw-r--r--test/Main.hs3
4 files changed, 27 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2ad1d26..12c247f 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -251,13 +251,20 @@ rerank ssh ssh1 ssh2 f (XArray arr)
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
= XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
(\a -> unXArray (f (XArray a)))
arr)
where
unXArray (XArray a) = a
+rerankTop :: forall sh sh1 sh2 a b.
+ (U.Unbox a, U.Unbox b)
+ => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
+ -> (XArray sh1 a -> XArray sh2 b)
+ -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
+rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
+
rerank2 :: forall sh sh1 sh2 a b c.
(U.Unbox a, U.Unbox b, U.Unbox c)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 983a636..aa7c7f9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -4,6 +4,7 @@ module Data.Array.Nested (
Ranked,
IxR(..),
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
+ rtranspose,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
@@ -12,6 +13,7 @@ module Data.Array.Nested (
IxS(..),
KnownShape(..), SShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
+ stranspose,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fa40921..d92c79c 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -315,6 +315,11 @@ mgenerate sh f
checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
+mtranspose perm =
+ mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
+ (X.transpose perm))
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'Nat'.
@@ -578,6 +583,11 @@ rsumOuter1 (Ranked arr)
. coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
$ arr
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mtranspose perm arr)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -640,3 +650,8 @@ ssumOuter1 (Shaped arr)
. X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr
+
+stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
+stranspose perm (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mtranspose perm arr)
diff --git a/test/Main.hs b/test/Main.hs
index 156e0a5..8257ff0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -9,7 +9,7 @@ import Data.Array.Nested
arr :: Ranked N2 (Shaped [N2, N3] (Double, Int))
arr = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) ->
sgenerate @[N2, N3] (2 ::$ 3 ::$ IZS) $ \(k ::$ l ::$ IZS) ->
- let s = i + j + k + l
+ let s = 24*i + 6*j + 3*k + l
in (fromIntegral s, s)
foo :: (Double, Int)
@@ -19,3 +19,4 @@ main :: IO ()
main = do
print arr
print foo
+ print (rtranspose [1,0] arr)