diff options
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 273 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Types.hs | 134 | 
3 files changed, 2 insertions, 409 deletions
| diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index cfb7bc6..ded6af5 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -12,9 +12,9 @@ import Data.Proxy  import Data.Type.Equality  import GHC.TypeLits -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types  import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types  -- * Reasoning helpers diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs deleted file mode 100644 index ef0afe3..0000000 --- a/src/Data/Array/Mixed/Permutation.hs +++ /dev/null @@ -1,273 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Permutation where - -import Data.Coerce (coerce) -import Data.Functor.Const -import Data.List (sort) -import Data.Maybe (fromMaybe) -import Data.Proxy -import Data.Type.Bool -import Data.Type.Equality -import Data.Type.Ord -import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Types -import Data.Array.Nested.Mixed.Shape - - --- * Permutations - --- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. -data Perm list where -  PNil :: Perm '[] -  PCons :: SNat a -> Perm l -> Perm (a : l) -infixr 5 `PCons` -deriving instance Show (Perm list) -deriving instance Eq (Perm list) - -permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat - -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case -  Just sn -> permFromList xs $ \list -> k (sn `PCons` list) -  Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: Perm list -> [Natural] -permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l - -permToList' :: Perm list -> [Int] -permToList' = map fromIntegral . permToList - --- | When called as @permCheckPermutation p k@, if @p@ is a permutation of --- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, --- then @Nothing@ is returned. -permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r -permCheckPermutation = \p k -> -  let n = permRank p -  in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of -       (Just Refl, Just Refl) -> Just k -       _ -> Nothing -  where -    lemElemCount :: (0 <= n, Compare n m ~ LT) -                 => proxy n -> proxy m -> Elem n (Count 0 m) :~: True -    lemElemCount _ _ = unsafeCoerceRefl - -    lemCount :: (OrdCond (Compare i n) True False True ~ True) -             => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n -    lemCount _ _ = unsafeCoerceRefl - -    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True -    lemElem _ _ = unsafeCoerceRefl - -    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' -               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) -    provePerm1 _ _ PNil = Just Refl -    provePerm1 p rtop@SNat (PCons sn@SNat perm) -      | Just Refl <- provePerm1 p rtop perm -      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of -          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl -          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl -          _ -> Nothing -      | otherwise -      = Nothing - -    provePerm2 :: SNat i -> SNat n -> Perm is' -               -> Maybe (AllElem' (Count i n) is' :~: True) -    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> -      case cmpNat i n of -        EQI -> Just Refl -        LTI | Refl <- lemCount i n -            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm -            -> checkElem i perm -            | otherwise -> Nothing -        GTI -> error "unreachable" -      where -        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) -        checkElem _ PNil = Nothing -        checkElem i@SNat (PCons k@SNat perm :: Perm is') = -          case sameNat i k of -            Just Refl -> Just Refl -            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl -                    | otherwise -> Nothing - --- | Utility class for generating permutations from type class information. -class KnownPerm l where makePerm :: Perm l -instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm - --- | Untyped permutations for ranked arrays -type PermR = [Int] - - --- ** Applying permutations - -type family Elem x l where -  Elem x '[] = 'False -  Elem x (x : _) = 'True -  Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where -  AllElem' '[] bs = 'True -  AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) -  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where -  Count n n = '[] -  Count i n = i : Count (i + 1) n - -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where -  Index 0 (n : sh) = n -  Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where -  Permute '[] sh = '[] -  Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -type family TakeLen ref l where -  TakeLen '[] l = '[] -  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where -  DropLen '[] l = l -  DropLen (_ : ref) (_ : xs) = DropLen ref xs - -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = -  listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) -  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) - -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - - --- * Operations on permutations - -permInverse :: Perm is -            -> (forall is'. -                     IsPermutation is' -                  => Perm is' -                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) -                  -> r) -            -> r -permInverse = \perm k -> -  genPerm perm $ \(invperm :: Perm is') -> -    fromMaybe -      (error $ "permInverse: did not generate permutation? perm = " ++ show perm -               ++ " ; invperm = " ++ show invperm) -      (permCheckPermutation invperm -        (k invperm -           (\ssh -> case provePermInverse perm invperm ssh of -                      Just eq -> eq -                      Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm -                                             ++ " ; invperm = " ++ show invperm))) -  where -    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r -    genPerm perm = -      let permList = permToList' perm -      in toHList $ map snd (sort (zip permList [0..])) -      where -        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r -        toHList [] k = k PNil -        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - -    provePermInverse :: Perm is -> Perm is' -> StaticShX sh -                     -> Maybe (Permute is' (Permute is sh) :~: sh) -    provePermInverse perm perminv ssh = -      ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh - -type family MapSucc is where -  MapSucc '[] = '[] -  MapSucc (i : is) = i + 1 : MapSucc is - -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc -  where -    permMapSucc :: Perm l -> Perm (MapSucc l) -    permMapSucc PNil = PNil -    permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns - - --- * Lemmas - -lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ PNil = Refl -lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) -               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -             -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs deleted file mode 100644 index 3f5b1e7..0000000 --- a/src/Data/Array/Mixed/Types.hs +++ /dev/null @@ -1,134 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Types ( -  -- * Reified evidence of a type class -  Dict(..), - -  -- * Type-level naturals -  pattern SZ, pattern SS, -  fromSNat', sameNat', -  snatPlus, snatMinus, snatMul, -  snatSucc, - -  -- * Type-level lists -  type (++), -  Replicate, -  lemReplicateSucc, -  MapJust, -  Head, -  Tail, -  Init, -  Last, - -  -- * Unsafe -  unsafeCoerceRefl, -) where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits -import GHC.TypeNats qualified as TN -import Unsafe.Coerce qualified - - --- | Evidence for the constraint @c a@. -data Dict c a where -  Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) -sameNat' n@SNat m@SNat = sameNat n m - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) -  where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) -  where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = -  withKnownNat snp1 $ -    case cmpNat (Proxy @1) (Proxy @np1) of -      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) -      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) -      GTI -> Nothing - --- This should be a function in base -snatPlus :: SNat n -> SNat m -> SNat (n + m) -snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMinus :: SNat n -> SNat m -> SNat (n - m) -snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMul :: SNat n -> SNat m -> SNat (n * m) -snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - - --- | Type-level list append. -type family l1 ++ l2 where -  '[] ++ l2 = l2 -  (x : xs) ++ l2 = x : xs ++ l2 - -type family Replicate n a where -  Replicate 0 a = '[] -  Replicate n a = a : Replicate (n - 1) a - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl - -type family MapJust l where -  MapJust '[] = '[] -  MapJust (x : xs) = Just x : MapJust xs - -type family Head l where -  Head (x : _) = x - -type family Tail l where -  Tail (_ : xs) = xs - -type family Init l where -  Init (x : y : xs) = x : Init (y : xs) -  Init '[x] = '[] - -type family Last l where -  Last (x : y : xs) = Last (y : xs) -  Last '[x] = x - - --- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to --- only typecheck for actual type equalities. One cannot, e.g. accidentally --- write this: --- --- @ --- foo :: Proxy a -> Proxy b -> a :~: b --- foo = unsafeCoerceRefl --- @ --- --- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', --- but would have resulted in interesting memory errors at runtime. -unsafeCoerceRefl :: a :~: b -unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl | 
