diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-11 00:11:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-15 11:06:40 +0100 |
commit | e6c20868375d2b7f6b31808844e1b48f78bca069 (patch) | |
tree | 5e3c3efa5c61eb11a28b486bccbbcac823a36614 /src/Data/Array/Mixed/Permutation.hs | |
parent | c705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff) |
WIP half-peano SNatspeano-snat
Diffstat (limited to 'src/Data/Array/Mixed/Permutation.hs')
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 88 |
1 files changed, 39 insertions, 49 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 331d5e0..e85e67f 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -13,8 +13,6 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Permutation where import Data.Coerce (coerce) @@ -24,13 +22,12 @@ import Data.Maybe (fromMaybe) import Data.Proxy import Data.Type.Bool import Data.Type.Equality -import Data.Type.Ord import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN +import Numeric.Natural import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.SNat.Peano -- * Permutations @@ -46,18 +43,19 @@ deriving instance Show (Perm list) deriving instance Eq (Perm list) permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat +permRank PNil = SZ +permRank (_ `PCons` l) = SS (permRank l) permFromList :: [Int] -> (forall list. Perm list -> r) -> r permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromList (x : xs) k = + withSomeSNat' x $ \sn -> + permFromList xs $ \list -> + k (sn `PCons` list) permToList :: Perm list -> [Natural] permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l +permToList (x `PCons` l) = fromSNat x : permToList l permToList' :: Perm list -> [Int] permToList' = map fromIntegral . permToList @@ -68,48 +66,47 @@ permToList' = map fromIntegral . permToList permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r permCheckPermutation = \p k -> let n = permRank p - in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of (Just Refl, Just Refl) -> Just k _ -> Nothing where - lemElemCount :: (0 <= n, Compare n m ~ LT) - => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount :: (Z <= n, n < m) + => proxy n -> proxy m -> Elem n (Count Z m) :~: True lemElemCount _ _ = unsafeCoerceRefl - lemCount :: (OrdCond (Compare i n) True False True ~ True) - => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n lemCount _ _ = unsafeCoerceRefl lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True lemElem _ _ = unsafeCoerceRefl provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True) provePerm1 _ _ PNil = Just (Refl) - provePerm1 p rtop@SNat (PCons sn@SNat perm) + provePerm1 p rtop (PCons sn perm) | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + = case (snatCompare SZ sn, snatCompare sn rtop) of + (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl + (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl _ -> Nothing | otherwise = Nothing provePerm2 :: SNat i -> SNat n -> Perm is' -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + provePerm2 = \i n perm -> + case snatCompare i n of + SEQ -> Just Refl + SLT | Refl <- lemCount i n + , Just Refl <- provePerm2 (SS i) n perm -> checkElem i perm | otherwise -> Nothing - GTI -> error "unreachable" + SGT -> error "unreachable" where checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) checkElem _ PNil = Nothing - checkElem i@SNat (PCons k@SNat perm :: Perm is') = - case sameNat i k of + checkElem i (PCons k perm :: Perm is') = + case testEquality i k of Just Refl -> Just Refl Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl | otherwise -> Nothing @@ -117,7 +114,7 @@ permCheckPermutation = \p k -> -- | Utility class for generating permutations from type class information. class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = knownNat `PCons` makePerm -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -139,13 +136,13 @@ type AllElem as bs = Assert (AllElem' as bs) type family Count i n where Count n n = '[] - Count i n = i : Count (i + 1) n + Count i n = i : Count (S i) n -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) +type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as) type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh + Index Z (n : sh) = n + Index (S i) (_ : sh) = Index i sh type family Permute is sh where Permute '[] sh = '[] @@ -178,9 +175,7 @@ listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh +listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh listxIndex _ _ _ ZX = error "Index into empty shape" listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f @@ -199,7 +194,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) @@ -236,23 +231,23 @@ permInverse = \perm k -> where toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r toHList [] k = k PNil - toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l) provePermInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = - ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where MapSucc '[] = '[] - MapSucc (i : is) = i + 1 : MapSucc is + MapSucc (i : is) = S i : MapSucc is -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc +permShift1 :: Perm l -> Perm (Z : MapSucc l) +permShift1 = (SZ `PCons`) . permMapSucc where permMapSucc :: Perm l -> Perm (MapSucc l) permMapSucc PNil = PNil - permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns -- * Lemmas @@ -266,8 +261,3 @@ lemRankDropLen :: forall is sh. (Rank is <= Rank sh) lemRankDropLen ZKX PNil = Refl lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l - -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl |