aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs4
-rw-r--r--src/Data/Array/Nested/Internal.hs6
2 files changed, 6 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 856e6cb..5fbc46f 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -418,6 +418,8 @@ type family Permute is sh where
Permute '[] sh = '[]
Permute (i : is) sh = Index i sh : Permute is sh
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
data HList f list where
HNil :: HList f '[]
HCons :: f a -> HList f l -> HList f (a : l)
@@ -486,7 +488,7 @@ ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
=> HList SNat is
-> XArray sh a
- -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a
+ -> XArray (PermutePrefix is sh) a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 65c5419..27e1a30 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -552,11 +552,11 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
+mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
mtranspose perm
| Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh)))
= mlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @sh) (knownShapeX @(X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh)) (knownShapeX @sh')
+ X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh')
(X.transpose perm)
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
@@ -1378,7 +1378,7 @@ shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
= shIndex p pT i sh rest
shIndex _ _ _ ZSS _ = error "Index into empty shape"
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
, Refl <- lemRankMapJust (Proxy @sh)