aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-05-16 23:50:45 +0200
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-05-16 23:50:45 +0200
commit7bc9bc6ddbcee4f4193d5c79db92186f12ce3eb2 (patch)
tree0f240a22377af2048c7c7a619fdddb9c2a007035 /src/Data/Array/Mixed
parentedc711c4cfb9c07466e7d242c6949131acc56f71 (diff)
Move modules Permutation and Types
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs4
-rw-r--r--src/Data/Array/Mixed/Permutation.hs273
-rw-r--r--src/Data/Array/Mixed/Types.hs134
3 files changed, 2 insertions, 409 deletions
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
index cfb7bc6..ded6af5 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -12,9 +12,9 @@ import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Types
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
-- * Reasoning helpers
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
deleted file mode 100644
index ef0afe3..0000000
--- a/src/Data/Array/Mixed/Permutation.hs
+++ /dev/null
@@ -1,273 +0,0 @@
-{-# LANGUAGE ConstraintKinds #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Permutation where
-
-import Data.Coerce (coerce)
-import Data.Functor.Const
-import Data.List (sort)
-import Data.Maybe (fromMaybe)
-import Data.Proxy
-import Data.Type.Bool
-import Data.Type.Equality
-import Data.Type.Ord
-import GHC.TypeError
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Mixed.Shape
-
-
--- * Permutations
-
--- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
-data Perm list where
- PNil :: Perm '[]
- PCons :: SNat a -> Perm l -> Perm (a : l)
-infixr 5 `PCons`
-deriving instance Show (Perm list)
-deriving instance Eq (Perm list)
-
-permRank :: Perm list -> SNat (Rank list)
-permRank PNil = SNat
-permRank (_ `PCons` l) | SNat <- permRank l = SNat
-
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
-
-permToList :: Perm list -> [Natural]
-permToList PNil = mempty
-permToList (x `PCons` l) = TN.fromSNat x : permToList l
-
-permToList' :: Perm list -> [Int]
-permToList' = map fromIntegral . permToList
-
--- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
--- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
--- then @Nothing@ is returned.
-permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
-permCheckPermutation = \p k ->
- let n = permRank p
- in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
- (Just Refl, Just Refl) -> Just k
- _ -> Nothing
- where
- lemElemCount :: (0 <= n, Compare n m ~ LT)
- => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
- lemElemCount _ _ = unsafeCoerceRefl
-
- lemCount :: (OrdCond (Compare i n) True False True ~ True)
- => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
- lemCount _ _ = unsafeCoerceRefl
-
- lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
- lemElem _ _ = unsafeCoerceRefl
-
- provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
- provePerm1 _ _ PNil = Just Refl
- provePerm1 p rtop@SNat (PCons sn@SNat perm)
- | Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- _ -> Nothing
- | otherwise
- = Nothing
-
- provePerm2 :: SNat i -> SNat n -> Perm is'
- -> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
- -> checkElem i perm
- | otherwise -> Nothing
- GTI -> error "unreachable"
- where
- checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
- checkElem _ PNil = Nothing
- checkElem i@SNat (PCons k@SNat perm :: Perm is') =
- case sameNat i k of
- Just Refl -> Just Refl
- Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
- | otherwise -> Nothing
-
--- | Utility class for generating permutations from type class information.
-class KnownPerm l where makePerm :: Perm l
-instance KnownPerm '[] where makePerm = PNil
-instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
-
--- | Untyped permutations for ranked arrays
-type PermR = [Int]
-
-
--- ** Applying permutations
-
-type family Elem x l where
- Elem x '[] = 'False
- Elem x (x : _) = 'True
- Elem x (_ : ys) = Elem x ys
-
-type family AllElem' as bs where
- AllElem' '[] bs = 'True
- AllElem' (a : as) bs = Elem a bs && AllElem' as bs
-
-type AllElem as bs = Assert (AllElem' as bs)
- (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
-
-type family Count i n where
- Count n n = '[]
- Count i n = i : Count (i + 1) n
-
-type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
-
-type family Index i sh where
- Index 0 (n : sh) = n
- Index i (_ : sh) = Index (i - 1) sh
-
-type family Permute is sh where
- Permute '[] sh = '[]
- Permute (i : is) sh = Index i sh : Permute is sh
-
-type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
-
-type family TakeLen ref l where
- TakeLen '[] l = '[]
- TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
-
-type family DropLen ref l where
- DropLen '[] l = l
- DropLen (_ : ref) (_ : xs) = DropLen ref xs
-
-listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen PNil _ = ZX
-listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen PNil sh = sh
-listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute PNil _ = ZX
-listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
-
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
-
-listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
-
-ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
-ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
-
-ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
-
-ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
-
-ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
-
-ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
-
-shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-
-
--- * Operations on permutations
-
-permInverse :: Perm is
- -> (forall is'.
- IsPermutation is'
- => Perm is'
- -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
- -> r)
- -> r
-permInverse = \perm k ->
- genPerm perm $ \(invperm :: Perm is') ->
- fromMaybe
- (error $ "permInverse: did not generate permutation? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)
- (permCheckPermutation invperm
- (k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
- Just eq -> eq
- Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)))
- where
- genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
- genPerm perm =
- let permList = permToList' perm
- in toHList $ map snd (sort (zip permList [0..]))
- where
- toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
- toHList [] k = k PNil
- toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
-
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
- -> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
- ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
-
-type family MapSucc is where
- MapSucc '[] = '[]
- MapSucc (i : is) = i + 1 : MapSucc is
-
-permShift1 :: Perm l -> Perm (0 : MapSucc l)
-permShift1 = (SNat @0 `PCons`) . permMapSucc
- where
- permMapSucc :: Perm l -> Perm (MapSucc l)
- permMapSucc PNil = PNil
- permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
-
-
--- * Lemmas
-
-lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
-lemRankPermute _ PNil = Refl
-lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
-
-lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
- => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemRankDropLen ZKX PNil = Refl
-lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
-lemRankDropLen (_ :!% _) PNil = Refl
-lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
-
-lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
- -> Index (i + 1) (a : l) :~: Index i l
-lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
deleted file mode 100644
index 3f5b1e7..0000000
--- a/src/Data/Array/Mixed/Types.hs
+++ /dev/null
@@ -1,134 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Types (
- -- * Reified evidence of a type class
- Dict(..),
-
- -- * Type-level naturals
- pattern SZ, pattern SS,
- fromSNat', sameNat',
- snatPlus, snatMinus, snatMul,
- snatSucc,
-
- -- * Type-level lists
- type (++),
- Replicate,
- lemReplicateSucc,
- MapJust,
- Head,
- Tail,
- Init,
- Last,
-
- -- * Unsafe
- unsafeCoerceRefl,
-) where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-import Unsafe.Coerce qualified
-
-
--- | Evidence for the constraint @c a@.
-data Dict c a where
- Dict :: c a => Dict c a
-
-fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
-
-sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
-sameNat' n@SNat m@SNat = sameNat n m
-
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
--- This should be a function in base
-snatPlus :: SNat n -> SNat m -> SNat (n + m)
-snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMinus :: SNat n -> SNat m -> SNat (n - m)
-snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMul :: SNat n -> SNat m -> SNat (n * m)
-snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl
-
-type family MapJust l where
- MapJust '[] = '[]
- MapJust (x : xs) = Just x : MapJust xs
-
-type family Head l where
- Head (x : _) = x
-
-type family Tail l where
- Tail (_ : xs) = xs
-
-type family Init l where
- Init (x : y : xs) = x : Init (y : xs)
- Init '[x] = '[]
-
-type family Last l where
- Last (x : y : xs) = Last (y : xs)
- Last '[x] = x
-
-
--- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
--- only typecheck for actual type equalities. One cannot, e.g. accidentally
--- write this:
---
--- @
--- foo :: Proxy a -> Proxy b -> a :~: b
--- foo = unsafeCoerceRefl
--- @
---
--- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
--- but would have resulted in interesting memory errors at runtime.
-unsafeCoerceRefl :: a :~: b
-unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl