aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Permutation.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Permutation.hs')
-rw-r--r--src/Data/Array/Nested/Permutation.hs104
1 files changed, 73 insertions, 31 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 065c9fd..ecdb06d 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -18,7 +18,6 @@
module Data.Array.Nested.Permutation where
import Data.Coerce (coerce)
-import Data.Functor.Const
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Proxy
@@ -172,52 +171,95 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen PNil _ = ZX
-listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen PNil sh = sh
-listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i
+listhTakeLen PNil _ = ZH
+listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh
+listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh
+listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute PNil _ = ZX
-listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
+listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i
+listhDropLen PNil sh = sh
+listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
+listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i
+listhPermute PNil _ = ZH
+listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) =
+ case listhIndex i sh of
+ SUnknown x -> x `ConsUnknown` listhPermute is sh
+ SKnown x -> x `ConsKnown` listhPermute is sh
-listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh)
+listhIndex SZ (n `ConsUnknown` _) = SUnknown n
+listhIndex SZ (n `ConsKnown` _) = SKnown n
+listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh')
+ = listhIndex i sh
+listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh')
+ = listhIndex i sh
+listhIndex _ ZH = error "Index into empty shape"
-ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
-ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i
+listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh)
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+ssxTakeLen = coerce (listhTakeLen @())
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+ssxDropLen = coerce (listhDropLen @())
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+ssxPermute = coerce (listhPermute @())
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
+ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh)
+ssxIndex k = coerce (listhIndex @() k)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+ssxPermutePrefix = coerce (listhPermutePrefix @())
+
+shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh)
+shxTakeLen = coerce (listhTakeLen @Int)
+
+shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh)
+shxDropLen = coerce (listhDropLen @Int)
+
+shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
+shxPermute = coerce (listhPermute @Int)
+
+shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh)
+shxIndex k = coerce (listhIndex @i k)
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+shxPermutePrefix = coerce (listhPermutePrefix @Int)
+
+listxTakeLen :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+listxDropLen :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex i sh ::% listxPermute is sh
+
+listxIndex :: forall j i sh. SNat i -> ListX sh j -> j
+listxIndex SZ (n ::% _) = n
+listxIndex (SS i) (_ ::% sh) = listxIndex i sh
+listxIndex _ ZX = error "Index into empty shape"
+
+listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @i)
-- * Operations on permutations