diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 74 | 
1 files changed, 53 insertions, 21 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index ededb60..2dc69ba 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -151,6 +151,19 @@ lemReplicatePlusApp sn _ _ = go sn        , Refl <- go n        = sym (X.lemReplicateSucc @a @(n'm1 + m)) +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +lemDropLenApp :: X.Rank l1 <= X.Rank l2 +              => Proxy l1 -> Proxy l2 -> Proxy rest +              -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerce Refl + +lemTakeLenApp :: X.Rank l1 <= X.Rank l2 +              => Proxy l1 -> Proxy l2 -> Proxy rest +              -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerce Refl +  -- === NEW INDEX TYPES === -- @@ -540,6 +553,9 @@ class Elt a where    mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a +  mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) +             => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a +    -- ====== PRIVATE METHODS ====== --    mshapeTree :: a -> ShapeTree a @@ -618,6 +634,10 @@ instance Storable a => Elt (Primitive a) where      let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'      in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) +  mtranspose perm (M_Primitive sh arr) = +    M_Primitive (X.shPermutePrefix perm sh) +                (X.transpose (X.staticShapeFrom sh) perm arr) +    mshapeTree _ = ()    mshapeTreeEq _ () () = True    mshapeTreeEmpty _ () = False @@ -668,6 +688,8 @@ instance (Elt a, Elt b) => Elt (a, b) where    mcast ssh1 sh2 psh' (M_Tup2 a b) =      M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) +  mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) +    mshapeTree (x, y) = (mshapeTree x, mshapeTree y)    mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'    mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 @@ -755,6 +777,19 @@ instance Elt a => Elt (Mixed sh' a) where      = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T        in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) +  mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) +             => HList SNat is -> Mixed sh (Mixed sh' a) +             -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) +  mtranspose perm (M_Nest sh arr) +    | let sh' = X.shDropSh @sh @sh' (mshape arr) sh +    , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') +    , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) +    , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') +    , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') +    , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') +    = M_Nest (X.shPermutePrefix perm sh) +             (mtranspose perm arr) +    mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)    mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) @@ -825,13 +860,6 @@ mgenerate sh f = case X.enumShape sh of                    mvecsWrite sh idx val vecs                  mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) -           => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a -mtranspose perm arr = -  let ssh = X.staticShapeFrom (mshape arr) -      sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh) -  in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr -  mappend :: forall n m sh a. Elt a          => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a  mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 @@ -1002,6 +1030,8 @@ instance Elt a => Elt (Ranked n a) where    mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) +  mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) +    mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1098,6 +1128,8 @@ instance Elt a => Elt (Shaped sh a) where    mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) +  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) +    mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -1499,21 +1531,21 @@ lemCommMapJustPermute (i `HCons` is) sh    , Refl <- lemCommMapJustIndex i sh    = Refl -shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) -shTakeLen HNil _ = ZSS -shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh -shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" +shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen HNil _ = ZSS +shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh +shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" -shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) -shPermute HNil _ = ZSS -shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) +shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute HNil _ = ZSS +shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh) -shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) -shIndex _ _ SZ (n :$$ _) rest = n :$$ rest -shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest    | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = shIndex p pT i sh rest -shIndex _ _ _ ZSS _ = error "Index into empty shape" +  = shsIndex p pT i sh rest +shsIndex _ _ _ ZSS _ = error "Index into empty shape"  stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)             => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a @@ -1521,8 +1553,8 @@ stranspose perm sarr@(Shaped arr)    | Refl <- lemRankMapJust (sshape sarr)    , Refl <- lemCommMapJustTakeLen perm (sshape sarr)    , Refl <- lemCommMapJustDropLen perm (sshape sarr) -  , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr)) -  , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) +  , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) +  , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))    = Shaped (mtranspose perm arr)  sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a | 
