From 5a802da40e5836ee19d46b9a2c771912dbff010e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 May 2024 17:36:47 +0200 Subject: applyPerm* functions --- src/Data/Array/Mixed.hs | 12 ++++- src/Data/Array/Nested/Internal.hs | 93 ++++++++++++++++++++++++++++++++++----- 2 files changed, 93 insertions(+), 12 deletions(-) diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 6766d90..4ae89a1 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -866,6 +866,15 @@ invertPermutation = \perm k -> provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh +applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f +applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i +applyPermIxX = coerce (applyPermX @(Const i)) + +applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i +applyPermShX = coerce (applyPermX @(SMayNat i SNat)) + class KnownNatList l where makeNatList :: HList SNat l instance KnownNatList '[] where makeNatList = HNil instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList @@ -881,8 +890,7 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm - = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] - in XArray (S.transpose perm' arr) + = XArray (S.transpose (permToList perm) arr) -- | The list argument gives indices into the original dimension list. -- diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f8d16aa..712c5f1 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -153,6 +153,9 @@ lemReplicatePlusApp sn _ _ = go sn lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl +lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerce Refl + lemDropLenApp :: X.Rank l1 <= X.Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) @@ -197,6 +200,19 @@ showListR f l = showString "[" . go "" l . showString "]" go _ ZR = id go prefix (x ::: xs) = showString prefix . f x . go "," xs +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + -- | An index into a rank-typed array. type role IxR nominal representational @@ -1497,6 +1513,36 @@ rsumOuter1 :: forall n a. (NumElt a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +applyPermR :: forall i n. [Int] -> ListR n i -> ListR n i +applyPermR = \perm sh -> + listrFromList perm $ \sperm -> + case (snatFromListR sperm, snatFromListR sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TypeNats.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "applyPermR: Index in permutation out of range" + +applyPermIxR :: forall n i. [Int] -> IxR n i -> IxR n i +applyPermIxR = coerce (applyPermR @i) + +applyPermShR :: forall n i. [Int] -> ShR n i -> ShR n i +applyPermShR = coerce (applyPermR @i) + rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a rtranspose perm arr | sn@SNat <- snatFromShR (rshape arr) @@ -1835,21 +1881,48 @@ lemCommMapJustPermute (i `HCons` is) sh , Refl <- lemCommMapJustIndex i sh = Refl +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f +listsTakeLen HNil _ = ZS +listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh +listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape" + +listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLen HNil sh = sh +listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh +listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f +listsPermute HNil _ = ZS +listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) + +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f +listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest + | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listsIndex p pT i sh rest +listsIndex _ _ _ ZS _ = error "Index into empty shape" + shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) -shsTakeLen HNil _ = ZSS -shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh -shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" +shsTakeLen = coerce (listsTakeLen @SNat) shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) -shsPermute HNil _ = ZSS -shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh) +shsPermute = coerce (listsPermute @SNat) shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) -shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest -shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest - | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = shsIndex p pT i sh rest -shsIndex _ _ _ ZSS _ = error "Index into empty shape" +shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) + +applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) + +applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS = coerce (applyPermS @(Const i)) + +applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS = coerce (applyPermS @SNat) stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a -- cgit v1.2.3-70-g09d2