diff options
Diffstat (limited to 'src/Data/Array/Nested/Permutation.hs')
-rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 031755f..03d1640 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -4,7 +4,6 @@ {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +24,7 @@ import Data.Proxy import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord +import GHC.Exts (withDict) import GHC.TypeError import GHC.TypeLits import GHC.TypeNats qualified as TN @@ -36,8 +36,8 @@ import Data.Array.Nested.Types -- * Permutations -- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. data Perm list where PNil :: Perm '[] PCons :: SNat a -> Perm l -> Perm (a : l) @@ -45,6 +45,13 @@ infixr 5 `PCons` deriving instance Show (Perm list) deriving instance Eq (Perm list) +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + permRank :: Perm list -> SNat (Rank list) permRank PNil = SNat permRank (_ `PCons` l) | SNat <- permRank l = SNat @@ -119,6 +126,9 @@ class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -224,7 +234,7 @@ permInverse = \perm k -> ++ " ; invperm = " ++ show invperm) (permCheckPermutation invperm (k invperm - (\ssh -> case provePermInverse perm invperm ssh of + (\ssh -> case permCheckInverse perm invperm ssh of Just eq -> eq Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm ++ " ; invperm = " ++ show invperm))) @@ -238,9 +248,9 @@ permInverse = \perm k -> toHList [] k = k PNil toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - provePermInverse :: Perm is -> Perm is' -> StaticShX sh + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = + permCheckInverse perm perminv ssh = ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where |