diff options
Diffstat (limited to 'src/Data/Array/Nested/Permutation.hs')
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 104 |
1 files changed, 73 insertions, 31 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 065c9fd..ecdb06d 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -18,7 +18,6 @@ module Data.Array.Nested.Permutation where import Data.Coerce (coerce) -import Data.Functor.Const import Data.List (sort) import Data.Maybe (fromMaybe) import Data.Proxy @@ -172,52 +171,95 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" +listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i +listhTakeLen PNil _ = ZH +listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh +listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh +listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh +listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i +listhDropLen PNil sh = sh +listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" +listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i +listhPermute PNil _ = ZH +listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = + case listhIndex i sh of + SUnknown x -> x `ConsUnknown` listhPermute is sh + SKnown x -> x `ConsKnown` listhPermute is sh -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) +listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh) +listhIndex SZ (n `ConsUnknown` _) = SUnknown n +listhIndex SZ (n `ConsKnown` _) = SKnown n +listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') + = listhIndex i sh +listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') + = listhIndex i sh +listhIndex _ ZH = error "Index into empty shape" -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) +listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i +listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh) ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) +ssxTakeLen = coerce (listhTakeLen @()) ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) +ssxDropLen = coerce (listhDropLen @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) +ssxPermute = coerce (listhPermute @()) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) +ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh) +ssxIndex k = coerce (listhIndex @() k) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) +ssxPermutePrefix = coerce (listhPermutePrefix @()) + +shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) +shxTakeLen = coerce (listhTakeLen @Int) + +shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) +shxDropLen = coerce (listhDropLen @Int) + +shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) +shxPermute = coerce (listhPermute @Int) + +shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh) +shxIndex k = coerce (listhIndex @i k) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +shxPermutePrefix = coerce (listhPermutePrefix @Int) + +listxTakeLen :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex i sh ::% listxPermute is sh + +listxIndex :: forall j i sh. SNat i -> ListX sh j -> j +listxIndex SZ (n ::% _) = n +listxIndex (SS i) (_ ::% sh) = listxIndex i sh +listxIndex _ ZX = error "Index into empty shape" + +listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @i) -- * Operations on permutations |
