aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Permutation.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Permutation.hs')
-rw-r--r--src/Data/Array/Nested/Permutation.hs42
1 files changed, 30 insertions, 12 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index bed2877..065c9fd 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
@@ -24,6 +25,7 @@ import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
+import GHC.Exts (withDict)
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
@@ -35,8 +37,8 @@ import Data.Array.Nested.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
data Perm list where
PNil :: Perm '[]
PCons :: SNat a -> Perm l -> Perm (a : l)
@@ -44,15 +46,22 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ testEquality _ _ = Nothing
+
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r
+permFromListCont [] k = k PNil
+permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x
permToList :: Perm list -> [Natural]
permToList PNil = mempty
@@ -118,6 +127,9 @@ class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -198,7 +210,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
@@ -223,7 +235,7 @@ permInverse = \perm k ->
++ " ; invperm = " ++ show invperm)
(permCheckPermutation invperm
(k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
+ (\ssh -> case permCheckInverse perm invperm ssh of
Just eq -> eq
Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
++ " ; invperm = " ++ show invperm)))
@@ -237,9 +249,9 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
+ permCheckInverse perm perminv ssh =
ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
@@ -263,7 +275,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
=> StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX PNil = Refl
-lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is)
+ | Refl <- lemRankDropLen sh is
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+ = Refl
+#else
+ = unsafeCoerceRefl
+#endif
lemRankDropLen (_ :!% _) PNil = Refl
lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"