From 0f13e61b9eecb0a14e4c62e78218d1652d9d4cf2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 11:27:18 +0200 Subject: mtranspose without rerank --- src/Data/Array/Mixed.hs | 99 +++++++++++++++++++++++++++------------ src/Data/Array/Nested/Internal.hs | 74 ++++++++++++++++++++--------- 2 files changed, 122 insertions(+), 51 deletions(-) diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 6ac3ab3..f62d781 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -303,27 +303,33 @@ completeShXzeros ZKX = ZSX completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh --- TODO: generalise all these things to arbitrary @i@ -ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') -ixAppend ZIX idx' = idx' -ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx' +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixAppend = coerce (listxAppend @_ @(Const i)) + +shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shAppend = coerce (listxAppend @_ @(SMayNat i SNat)) -shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') -shAppend ZSX sh' = sh' -shAppend (n :$% sh) sh' = n :$% shAppend sh sh' +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short -ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop long ZIX = long -ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short +ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixDrop = coerce (listxDrop @(Const i) @(Const i)) -shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' -shDropIx sh ZIX = sh -shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx +shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh' -ssxDropIx ssh ZIX = ssh -ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx +shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- TODO: generalise all these things to arbitrary @i@ shTail :: IShX (n : sh) -> IShX sh shTail (_ :$% sh) = sh @@ -603,26 +609,59 @@ lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerce Refl -ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen HNil _ = ZKX -ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh -ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape" +listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen HNil _ = ZX +listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen HNil sh = sh +listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f +listxPermute HNil _ = ZX +listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen HNil sh = sh -ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape" +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute HNil _ = ZKX -ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex _ _ SZ (n :!% _) rest = n :!% rest -ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = ssxIndex p pT i sh rest -ssxIndex _ _ _ ZKX _ = error "Index into empty shape" +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shTakeLen :: forall is sh. HList SNat is -> IShX sh -> IShX (TakeLen is sh) +shTakeLen = coerce (listxTakeLen @(SMayNat Int SNat)) + +shDropLen :: HList SNat is -> IShX sh -> IShX (DropLen is sh) +shDropLen = coerce (listxDropLen @(SMayNat Int SNat)) + +shPermute :: HList SNat is -> IShX sh -> IShX (Permute is sh) +shPermute = coerce (listxPermute @(SMayNat Int SNat)) + +shIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> IShX (Permute is shT) -> IShX (Index i sh : Permute is shT) +shIndex p1 p2 = coerce (listxIndex @(SMayNat Int SNat) p1 p2) + +shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) +shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) -- | The list argument gives indices into the original dimension list. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index ededb60..2dc69ba 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -151,6 +151,19 @@ lemReplicatePlusApp sn _ _ = go sn , Refl <- go n = sym (X.lemReplicateSucc @a @(n'm1 + m)) +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +lemDropLenApp :: X.Rank l1 <= X.Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerce Refl + +lemTakeLenApp :: X.Rank l1 <= X.Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerce Refl + -- === NEW INDEX TYPES === -- @@ -540,6 +553,9 @@ class Elt a where mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) + => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a + -- ====== PRIVATE METHODS ====== -- mshapeTree :: a -> ShapeTree a @@ -618,6 +634,10 @@ instance Storable a => Elt (Primitive a) where let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + mtranspose perm (M_Primitive sh arr) = + M_Primitive (X.shPermutePrefix perm sh) + (X.transpose (X.staticShapeFrom sh) perm arr) + mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False @@ -668,6 +688,8 @@ instance (Elt a, Elt b) => Elt (a, b) where mcast ssh1 sh2 psh' (M_Tup2 a b) = M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 @@ -755,6 +777,19 @@ instance Elt a => Elt (Mixed sh' a) where = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) + => HList SNat is -> Mixed sh (Mixed sh' a) + -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = X.shDropSh @sh @sh' (mshape arr) sh + , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') + , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) + , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (X.shPermutePrefix perm sh) + (mtranspose perm arr) + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) @@ -825,13 +860,6 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) - => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a -mtranspose perm arr = - let ssh = X.staticShapeFrom (mshape arr) - sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh) - in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr - mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 @@ -1002,6 +1030,8 @@ instance Elt a => Elt (Ranked n a) where mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1098,6 +1128,8 @@ instance Elt a => Elt (Shaped sh a) where mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1499,21 +1531,21 @@ lemCommMapJustPermute (i `HCons` is) sh , Refl <- lemCommMapJustIndex i sh = Refl -shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) -shTakeLen HNil _ = ZSS -shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh -shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" +shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen HNil _ = ZSS +shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh +shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" -shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) -shPermute HNil _ = ZSS -shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) +shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute HNil _ = ZSS +shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh) -shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) -shIndex _ _ SZ (n :$$ _) rest = n :$$ rest -shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = shIndex p pT i sh rest -shIndex _ _ _ ZSS _ = error "Index into empty shape" + = shsIndex p pT i sh rest +shsIndex _ _ _ ZSS _ = error "Index into empty shape" stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a @@ -1521,8 +1553,8 @@ stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemCommMapJustTakeLen perm (sshape sarr) , Refl <- lemCommMapJustDropLen perm (sshape sarr) - , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) + , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -- cgit v1.2.3-70-g09d2