From 36a89c7a26d73b0583ac765d29dadb1b918007f6 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 9 Apr 2024 18:35:35 +0200 Subject: Transpose functions in the API --- src/Data/Array/Mixed.hs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2ad1d26..12c247f 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -251,13 +251,20 @@ rerank ssh ssh1 ssh2 f (XArray arr) , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) (\a -> unXArray (f (XArray a))) arr) where unXArray (XArray a) = a +rerankTop :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + rerank2 :: forall sh sh1 sh2 a b c. (U.Unbox a, U.Unbox b, U.Unbox c) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 -- cgit v1.2.3-70-g09d2