aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Permutation.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Permutation.hs')
-rw-r--r--src/Data/Array/Mixed/Permutation.hs88
1 files changed, 39 insertions, 49 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 331d5e0..e85e67f 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -13,8 +13,6 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Permutation where
import Data.Coerce (coerce)
@@ -24,13 +22,12 @@ import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
-import Data.Type.Ord
import GHC.TypeError
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
+import Numeric.Natural
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+import Data.SNat.Peano
-- * Permutations
@@ -46,18 +43,19 @@ deriving instance Show (Perm list)
deriving instance Eq (Perm list)
permRank :: Perm list -> SNat (Rank list)
-permRank PNil = SNat
-permRank (_ `PCons` l) | SNat <- permRank l = SNat
+permRank PNil = SZ
+permRank (_ `PCons` l) = SS (permRank l)
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromList (x : xs) k =
+ withSomeSNat' x $ \sn ->
+ permFromList xs $ \list ->
+ k (sn `PCons` list)
permToList :: Perm list -> [Natural]
permToList PNil = mempty
-permToList (x `PCons` l) = TN.fromSNat x : permToList l
+permToList (x `PCons` l) = fromSNat x : permToList l
permToList' :: Perm list -> [Int]
permToList' = map fromIntegral . permToList
@@ -68,48 +66,47 @@ permToList' = map fromIntegral . permToList
permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
permCheckPermutation = \p k ->
let n = permRank p
- in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
+ in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of
(Just Refl, Just Refl) -> Just k
_ -> Nothing
where
- lemElemCount :: (0 <= n, Compare n m ~ LT)
- => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount :: (Z <= n, n < m)
+ => proxy n -> proxy m -> Elem n (Count Z m) :~: True
lemElemCount _ _ = unsafeCoerceRefl
- lemCount :: (OrdCond (Compare i n) True False True ~ True)
- => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n
lemCount _ _ = unsafeCoerceRefl
lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
lemElem _ _ = unsafeCoerceRefl
provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True)
provePerm1 _ _ PNil = Just (Refl)
- provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ provePerm1 p rtop (PCons sn perm)
| Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ = case (snatCompare SZ sn, snatCompare sn rtop) of
+ (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
+ (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
_ -> Nothing
| otherwise
= Nothing
provePerm2 :: SNat i -> SNat n -> Perm is'
-> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ provePerm2 = \i n perm ->
+ case snatCompare i n of
+ SEQ -> Just Refl
+ SLT | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SS i) n perm
-> checkElem i perm
| otherwise -> Nothing
- GTI -> error "unreachable"
+ SGT -> error "unreachable"
where
checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
checkElem _ PNil = Nothing
- checkElem i@SNat (PCons k@SNat perm :: Perm is') =
- case sameNat i k of
+ checkElem i (PCons k perm :: Perm is') =
+ case testEquality i k of
Just Refl -> Just Refl
Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
| otherwise -> Nothing
@@ -117,7 +114,7 @@ permCheckPermutation = \p k ->
-- | Utility class for generating permutations from type class information.
class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
-instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = knownNat `PCons` makePerm
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -139,13 +136,13 @@ type AllElem as bs = Assert (AllElem' as bs)
type family Count i n where
Count n n = '[]
- Count i n = i : Count (i + 1) n
+ Count i n = i : Count (S i) n
-type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as)
type family Index i sh where
- Index 0 (n : sh) = n
- Index i (_ : sh) = Index (i - 1) sh
+ Index Z (n : sh) = n
+ Index (S i) (_ : sh) = Index i sh
type family Permute is sh where
Permute '[] sh = '[]
@@ -178,9 +175,7 @@ listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
+listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh
listxIndex _ _ _ ZX = error "Index into empty shape"
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
@@ -199,7 +194,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
@@ -236,23 +231,23 @@ permInverse = \perm k ->
where
toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
toHList [] k = k PNil
- toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+ toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l)
provePermInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh =
- ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+ testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
MapSucc '[] = '[]
- MapSucc (i : is) = i + 1 : MapSucc is
+ MapSucc (i : is) = S i : MapSucc is
-permShift1 :: Perm l -> Perm (0 : MapSucc l)
-permShift1 = (SNat @0 `PCons`) . permMapSucc
+permShift1 :: Perm l -> Perm (Z : MapSucc l)
+permShift1 = (SZ `PCons`) . permMapSucc
where
permMapSucc :: Perm l -> Perm (MapSucc l)
permMapSucc PNil = PNil
- permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+ permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns
-- * Lemmas
@@ -266,8 +261,3 @@ lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
lemRankDropLen ZKX PNil = Refl
lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!% _) PNil = Refl
-lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
-
-lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
- -> Index (i + 1) (a : l) :~: Index i l
-lemIndexSucc _ _ _ = unsafeCoerceRefl