aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-28 17:36:47 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-28 17:36:47 +0200
commit5a802da40e5836ee19d46b9a2c771912dbff010e (patch)
tree9c794b27f6c861335e007a68dbb559776d1ffaaa /src/Data/Array
parent6b74bff29ea3c21adaeea12921aed057b5858040 (diff)
applyPerm* functions
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs12
-rw-r--r--src/Data/Array/Nested/Internal.hs93
2 files changed, 93 insertions, 12 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 6766d90..4ae89a1 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -866,6 +866,15 @@ invertPermutation = \perm k ->
provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh
+applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f
+applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i
+applyPermIxX = coerce (applyPermX @(Const i))
+
+applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i
+applyPermShX = coerce (applyPermX @(SMayNat i SNat))
+
class KnownNatList l where makeNatList :: HList SNat l
instance KnownNatList '[] where makeNatList = HNil
instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
@@ -881,8 +890,7 @@ transpose ssh perm (XArray arr)
, Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
, Refl <- lemRankDropLen ssh perm
- = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
- in XArray (S.transpose perm' arr)
+ = XArray (S.transpose (permToList perm) arr)
-- | The list argument gives indices into the original dimension list.
--
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index f8d16aa..712c5f1 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -153,6 +153,9 @@ lemReplicatePlusApp sn _ _ = go sn
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
+lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
+lemLeqSuccSucc _ _ = unsafeCoerce Refl
+
lemDropLenApp :: X.Rank l1 <= X.Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
-> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest)
@@ -197,6 +200,19 @@ showListR f l = showString "[" . go "" l . showString "]"
go _ ZR = id
go prefix (x ::: xs) = showString prefix . f x . go "," xs
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
-- | An index into a rank-typed array.
type role IxR nominal representational
@@ -1497,6 +1513,36 @@ rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+applyPermR :: forall i n. [Int] -> ListR n i -> ListR n i
+applyPermR = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (snatFromListR sperm, snatFromListR sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TypeNats.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "applyPermR: Index in permutation out of range"
+
+applyPermIxR :: forall n i. [Int] -> IxR n i -> IxR n i
+applyPermIxR = coerce (applyPermR @i)
+
+applyPermShR :: forall n i. [Int] -> ShR n i -> ShR n i
+applyPermShR = coerce (applyPermR @i)
+
rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a
rtranspose perm arr
| sn@SNat <- snatFromShR (rshape arr)
@@ -1835,21 +1881,48 @@ lemCommMapJustPermute (i `HCons` is) sh
, Refl <- lemCommMapJustIndex i sh
= Refl
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f
+listsTakeLen HNil _ = ZS
+listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
+listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+
+listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLen HNil sh = sh
+listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh
+listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+
+listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f
+listsPermute HNil _ = ZS
+listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
+
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f
+listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest
+ | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh rest
+listsIndex _ _ _ ZS _ = error "Index into empty shape"
+
shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
-shsTakeLen HNil _ = ZSS
-shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh
-shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+shsTakeLen = coerce (listsTakeLen @SNat)
shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
-shsPermute HNil _ = ZSS
-shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh)
+shsPermute = coerce (listsPermute @SNat)
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
-shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest
-shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
- | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = shsIndex p pT i sh rest
-shsIndex _ _ _ ZSS _ = error "Index into empty shape"
+shsIndex pis pshT = coerce (listsIndex @SNat pis pshT)
+
+applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh)
+
+applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS = coerce (applyPermS @(Const i))
+
+applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS = coerce (applyPermS @SNat)
stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
=> HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a