diff options
Diffstat (limited to 'src/Data/Array/Nested/Shape.hs')
-rw-r--r-- | src/Data/Array/Nested/Shape.hs | 467 |
1 files changed, 0 insertions, 467 deletions
diff --git a/src/Data/Array/Nested/Shape.hs b/src/Data/Array/Nested/Shape.hs deleted file mode 100644 index 774b4bd..0000000 --- a/src/Data/Array/Nested/Shape.hs +++ /dev/null @@ -1,467 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Shape where - -import Data.Array.Mixed.Types -import Data.Coerce (coerce) -import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Kind (Type, Constraint) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import GHC.IsList (IsList) -import GHC.IsList qualified as IsList -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - -listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR sh' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrToSNat :: ListR n i -> SNat n -listrToSNat ZR = SNat -listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrToSNat sperm, listrToSNat sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -ixrToSNat :: IxR n i -> SNat n -ixrToSNat (IxR sh) = listrToSNat sh - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: (ShR sh) = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrToSNat :: ShR n i -> SNat n -shrToSNat (ShR sh) = listrToSNat sh - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList = go (SNat @n) - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error "IsList(ListR): Mismatched list length" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList - - -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -instance (forall n. Show (f n)) => Show (ListS sh f) where - showsPrec _ = listsShow shows - -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing - -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -listsShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListS sh' f -> ShowS - go _ ZS = id - go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLenPerm PNil _ = ZS -listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh -listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLenPerm PNil sh = sh -listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh -listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh - --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" - -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) - -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) - -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) - -applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) - -applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -applyPermIxS = coerce (applyPermS @(Const i)) - -applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -applyPermShS = coerce (applyPermS @SNat) - - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) -infixr 3 :.$ - -{-# COMPLETE ZIS, (:.$) #-} - -type IIxS sh = IxS sh Int - -instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l - -instance Functor (IxS sh) where - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l - -ixsZero :: ShS sh -> IIxS sh -ixsZero ZSS = ZIS -ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -type role ShS nominal -type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord) - -pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS - -pattern (:$$) - :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) - -infixr 3 :$$ - -{-# COMPLETE ZSS, (:$$) #-} - -instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l - -shsLength :: ShS sh -> Int -shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) - -shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh - -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - -shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShS :: [Nat] -> Constraint -class KnownShS sh where knownShS :: ShS sh -instance KnownShS '[] where knownShS = ZSS -instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS - - --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listsToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShS sh => IsList (IxS sh i) where - type Item (IxS sh i) = i - fromList = IxS . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and values are checked at runtime. -instance KnownShS sh => IsList (ShS sh) where - type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shsToList |