diff options
author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-16 23:50:45 +0200 |
---|---|---|
committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-16 23:50:45 +0200 |
commit | 7bc9bc6ddbcee4f4193d5c79db92186f12ce3eb2 (patch) | |
tree | 0f240a22377af2048c7c7a619fdddb9c2a007035 /src/Data/Array/Mixed | |
parent | edc711c4cfb9c07466e7d242c6949131acc56f71 (diff) |
Move modules Permutation and Types
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 273 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 134 |
3 files changed, 2 insertions, 409 deletions
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index cfb7bc6..ded6af5 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -12,9 +12,9 @@ import Data.Proxy import Data.Type.Equality import GHC.TypeLits -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types -- * Reasoning helpers diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs deleted file mode 100644 index ef0afe3..0000000 --- a/src/Data/Array/Mixed/Permutation.hs +++ /dev/null @@ -1,273 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Permutation where - -import Data.Coerce (coerce) -import Data.Functor.Const -import Data.List (sort) -import Data.Maybe (fromMaybe) -import Data.Proxy -import Data.Type.Bool -import Data.Type.Equality -import Data.Type.Ord -import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Types -import Data.Array.Nested.Mixed.Shape - - --- * Permutations - --- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. -data Perm list where - PNil :: Perm '[] - PCons :: SNat a -> Perm l -> Perm (a : l) -infixr 5 `PCons` -deriving instance Show (Perm list) -deriving instance Eq (Perm list) - -permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat - -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: Perm list -> [Natural] -permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l - -permToList' :: Perm list -> [Int] -permToList' = map fromIntegral . permToList - --- | When called as @permCheckPermutation p k@, if @p@ is a permutation of --- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, --- then @Nothing@ is returned. -permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r -permCheckPermutation = \p k -> - let n = permRank p - in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of - (Just Refl, Just Refl) -> Just k - _ -> Nothing - where - lemElemCount :: (0 <= n, Compare n m ~ LT) - => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerceRefl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) - => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerceRefl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerceRefl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ PNil = Just Refl - provePerm1 p rtop@SNat (PCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> Perm is' - -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) - checkElem _ PNil = Nothing - checkElem i@SNat (PCons k@SNat perm :: Perm is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - --- | Utility class for generating permutations from type class information. -class KnownPerm l where makePerm :: Perm l -instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm - --- | Untyped permutations for ranked arrays -type PermR = [Int] - - --- ** Applying permutations - -type family Elem x l where - Elem x '[] = 'False - Elem x (x : _) = 'True - Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where - AllElem' '[] bs = 'True - AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) - (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where - Count n n = '[] - Count i n = i : Count (i + 1) n - -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where - Permute '[] sh = '[] - Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -type family TakeLen ref l where - TakeLen '[] l = '[] - TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where - DropLen '[] l = l - DropLen (_ : ref) (_ : xs) = DropLen ref xs - -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) - -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - - --- * Operations on permutations - -permInverse :: Perm is - -> (forall is'. - IsPermutation is' - => Perm is' - -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) - -> r) - -> r -permInverse = \perm k -> - genPerm perm $ \(invperm :: Perm is') -> - fromMaybe - (error $ "permInverse: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - (permCheckPermutation invperm - (k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm))) - where - genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r - genPerm perm = - let permList = permToList' perm - in toHList $ map snd (sort (zip permList [0..])) - where - toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r - toHList [] k = k PNil - toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - - provePermInverse :: Perm is -> Perm is' -> StaticShX sh - -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = - ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh - -type family MapSucc is where - MapSucc '[] = '[] - MapSucc (i : is) = i + 1 : MapSucc is - -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc - where - permMapSucc :: Perm l -> Perm (MapSucc l) - permMapSucc PNil = PNil - permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns - - --- * Lemmas - -lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ PNil = Refl -lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) - => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l - -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs deleted file mode 100644 index 3f5b1e7..0000000 --- a/src/Data/Array/Mixed/Types.hs +++ /dev/null @@ -1,134 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Types ( - -- * Reified evidence of a type class - Dict(..), - - -- * Type-level naturals - pattern SZ, pattern SS, - fromSNat', sameNat', - snatPlus, snatMinus, snatMul, - snatSucc, - - -- * Type-level lists - type (++), - Replicate, - lemReplicateSucc, - MapJust, - Head, - Tail, - Init, - Last, - - -- * Unsafe - unsafeCoerceRefl, -) where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits -import GHC.TypeNats qualified as TN -import Unsafe.Coerce qualified - - --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) -sameNat' n@SNat m@SNat = sameNat n m - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - --- This should be a function in base -snatPlus :: SNat n -> SNat m -> SNat (n + m) -snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMinus :: SNat n -> SNat m -> SNat (n - m) -snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMul :: SNat n -> SNat m -> SNat (n * m) -snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -type family Head l where - Head (x : _) = x - -type family Tail l where - Tail (_ : xs) = xs - -type family Init l where - Init (x : y : xs) = x : Init (y : xs) - Init '[x] = '[] - -type family Last l where - Last (x : y : xs) = Last (y : xs) - Last '[x] = x - - --- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to --- only typecheck for actual type equalities. One cannot, e.g. accidentally --- write this: --- --- @ --- foo :: Proxy a -> Proxy b -> a :~: b --- foo = unsafeCoerceRefl --- @ --- --- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', --- but would have resulted in interesting memory errors at runtime. -unsafeCoerceRefl :: a :~: b -unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl |