aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs74
1 files changed, 53 insertions, 21 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index ededb60..2dc69ba 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -151,6 +151,19 @@ lemReplicatePlusApp sn _ _ = go sn
, Refl <- go n
= sym (X.lemReplicateSucc @a @(n'm1 + m))
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
+lemLeqPlus _ _ _ = Refl
+
+lemDropLenApp :: X.Rank l1 <= X.Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest)
+lemDropLenApp _ _ _ = unsafeCoerce Refl
+
+lemTakeLenApp :: X.Rank l1 <= X.Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
+lemTakeLenApp _ _ _ = unsafeCoerce Refl
+
-- === NEW INDEX TYPES === --
@@ -540,6 +553,9 @@ class Elt a where
mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+ mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
+ => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
+
-- ====== PRIVATE METHODS ====== --
mshapeTree :: a -> ShapeTree a
@@ -618,6 +634,10 @@ instance Storable a => Elt (Primitive a) where
let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'
in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr)
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (X.shPermutePrefix perm sh)
+ (X.transpose (X.staticShapeFrom sh) perm arr)
+
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
@@ -668,6 +688,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
mcast ssh1 sh2 psh' (M_Tup2 a b) =
M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
@@ -755,6 +777,19 @@ instance Elt a => Elt (Mixed sh' a) where
= let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T
in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+ mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
+ => HList SNat is -> Mixed sh (Mixed sh' a)
+ -> Mixed (X.PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = X.shDropSh @sh @sh' (mshape arr) sh
+ , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh')
+ , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh'))
+ , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (X.shPermutePrefix perm sh)
+ (mtranspose perm arr)
+
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
@@ -825,13 +860,6 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
- => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
-mtranspose perm arr =
- let ssh = X.staticShapeFrom (mshape arr)
- sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh)
- in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr
-
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
@@ -1002,6 +1030,8 @@ instance Elt a => Elt (Ranked n a) where
mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -1098,6 +1128,8 @@ instance Elt a => Elt (Shaped sh a) where
mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -1499,21 +1531,21 @@ lemCommMapJustPermute (i `HCons` is) sh
, Refl <- lemCommMapJustIndex i sh
= Refl
-shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
-shTakeLen HNil _ = ZSS
-shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh
-shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shsTakeLen HNil _ = ZSS
+shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh
+shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
-shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
-shPermute HNil _ = ZSS
-shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh)
+shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shsPermute HNil _ = ZSS
+shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh)
-shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
-shIndex _ _ SZ (n :$$ _) rest = n :$$ rest
-shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest
+shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
| Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = shIndex p pT i sh rest
-shIndex _ _ _ ZSS _ = error "Index into empty shape"
+ = shsIndex p pT i sh rest
+shsIndex _ _ _ ZSS _ = error "Index into empty shape"
stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
=> HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
@@ -1521,8 +1553,8 @@ stranspose perm sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
, Refl <- lemCommMapJustTakeLen perm (sshape sarr)
, Refl <- lemCommMapJustDropLen perm (sshape sarr)
- , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr))
- , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
+ , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a