diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-29 12:36:03 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-29 12:40:02 +0200 |
commit | 260e00c3d661c21de5986ccf01d3292d3b8f7633 (patch) | |
tree | 4e29bed0adb2724ac48e61abf5376ba613ff1c2a /src/Data/Array/Nested/Mixed.hs | |
parent | 64404591661d3bc239804a1c17a25f81c434d852 (diff) |
Flip some index/shape-related functions
This ensures that the argument order consistently puts the main thing
being operated on at the end, and supporting singletons at the start.
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 652f1c6..a6e94b6 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -396,7 +396,7 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 @@ -438,7 +438,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -561,7 +561,7 @@ instance Elt a => Elt (Mixed sh' a) where Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -630,14 +630,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -827,8 +827,8 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) |