From 0f13e61b9eecb0a14e4c62e78218d1652d9d4cf2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 11:27:18 +0200 Subject: mtranspose without rerank --- src/Data/Array/Nested/Internal.hs | 74 ++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 21 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index ededb60..2dc69ba 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -151,6 +151,19 @@ lemReplicatePlusApp sn _ _ = go sn , Refl <- go n = sym (X.lemReplicateSucc @a @(n'm1 + m)) +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +lemDropLenApp :: X.Rank l1 <= X.Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerce Refl + +lemTakeLenApp :: X.Rank l1 <= X.Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerce Refl + -- === NEW INDEX TYPES === -- @@ -540,6 +553,9 @@ class Elt a where mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) + => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a + -- ====== PRIVATE METHODS ====== -- mshapeTree :: a -> ShapeTree a @@ -618,6 +634,10 @@ instance Storable a => Elt (Primitive a) where let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + mtranspose perm (M_Primitive sh arr) = + M_Primitive (X.shPermutePrefix perm sh) + (X.transpose (X.staticShapeFrom sh) perm arr) + mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False @@ -668,6 +688,8 @@ instance (Elt a, Elt b) => Elt (a, b) where mcast ssh1 sh2 psh' (M_Tup2 a b) = M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 @@ -755,6 +777,19 @@ instance Elt a => Elt (Mixed sh' a) where = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) + => HList SNat is -> Mixed sh (Mixed sh' a) + -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = X.shDropSh @sh @sh' (mshape arr) sh + , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') + , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) + , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (X.shPermutePrefix perm sh) + (mtranspose perm arr) + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) @@ -825,13 +860,6 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) - => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a -mtranspose perm arr = - let ssh = X.staticShapeFrom (mshape arr) - sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh) - in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr - mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 @@ -1002,6 +1030,8 @@ instance Elt a => Elt (Ranked n a) where mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1098,6 +1128,8 @@ instance Elt a => Elt (Shaped sh a) where mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1499,21 +1531,21 @@ lemCommMapJustPermute (i `HCons` is) sh , Refl <- lemCommMapJustIndex i sh = Refl -shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) -shTakeLen HNil _ = ZSS -shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh -shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" +shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen HNil _ = ZSS +shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh +shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" -shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) -shPermute HNil _ = ZSS -shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) +shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute HNil _ = ZSS +shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh) -shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) -shIndex _ _ SZ (n :$$ _) rest = n :$$ rest -shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = shIndex p pT i sh rest -shIndex _ _ _ ZSS _ = error "Index into empty shape" + = shsIndex p pT i sh rest +shsIndex _ _ _ ZSS _ = error "Index into empty shape" stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a @@ -1521,8 +1553,8 @@ stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemCommMapJustTakeLen perm (sshape sarr) , Refl <- lemCommMapJustDropLen perm (sshape sarr) - , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) + , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -- cgit v1.2.3-70-g09d2