{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Length-indexed and shape-indexed lists, together with the index and shape
-- types built on top of them: 'ListR' \/ 'IxR' \/ 'ShR' for rank-typed arrays
-- and 'ListS' \/ 'IxS' \/ 'ShS' for shape-typed arrays.
module Data.Array.Nested.Internal.Shape where

import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape


-- | A list of elements of type @i@ whose length @n@ is tracked in the type.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  ZR :: ListR 0 i
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

-- | Rendered in list notation (see 'listrShow'); the precedence is ignored.
instance Show i => Show (ListR n i) where
  showsPrec _ = listrShow shows

-- | Result of 'listrUncons': the tail and the head of a non-empty 'ListR',
-- packaged with the evidence that the original length is the tail length
-- plus one.
data UnconsListRRes i n1 = forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split a 'ListR' into its head and tail; 'Nothing' for the empty list.
-- Used as the view function behind the @(:.:)@ and @(:$:)@ pattern synonyms.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
listrUncons ZR = Nothing

-- | Show a 'ListR' as @[a,b,c]@, rendering each element with the given
-- function.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
  where
    -- The first argument is the separator to emit before the next element:
    -- empty for the first element, a comma afterwards.
    go :: String -> ListR n' i -> ShowS
    go _ ZR = id
    go prefix (x ::: xs) = showString prefix . f x . go "," xs

-- | Concatenate two lists; the result length is the sum of the lengths.
listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh

-- | Convert an ordinary list to a 'ListR'. The length is only known at
-- runtime, so the result is handed to a continuation that must work for any
-- length.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList [] k = k ZR
listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)

-- | First element of a non-empty list.
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
listrHead ZR = error "unreachable"  -- would require n + 1 ~ 0

-- | All but the first element of a non-empty list.
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"  -- would require n + 1 ~ 0

-- | Element at (zero-based) position @k@; the constraint guarantees that the
-- position is within bounds.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"  -- the constraint cannot hold for n ~ 0

-- | The length of the list as a singleton natural.
listrLengthSNat :: ListR n i -> SNat n
listrLengthSNat ZR = SNat
listrLengthSNat (_ ::: (sh :: ListR n i)) = snatSucc (listrLengthSNat sh)

-- | Rearrange the first @length perm@ elements of the list: element @j@ of
-- the result is element @perm !! j@ of the input. The remaining elements are
-- left in place.
--
-- The index list is not checked for being an actual permutation (duplicates
-- are accepted). Calls 'error' if the index list is longer than the list, or
-- if it contains an index that is negative or not smaller than
-- @length perm@.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrLengthSNat sperm, listrLengthSNat sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI have identical bodies, but each match brings different
        -- type-level evidence into scope, so they are kept as separate
        -- alternatives.
        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    -- Split a list into its first m elements and the rest.
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    -- Look up every index of the index list in the given list of length m.
    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- A negative index cannot be converted to a natural number: without
      -- this check, 'fromIntegral' below fails with a bare arithmetic
      -- underflow instead of a descriptive message.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"


-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord)
  deriving newtype (Functor, Foldable)

-- | The empty index.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

-- | An index component followed by the rest of the index.
pattern (:.:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

{-# COMPLETE ZIR, (:.:) #-}

-- | An 'IxR' with 'Int' components.
type IIxR n = IxR n Int

-- | Rendered in list notation (see 'listrShow'); the precedence is ignored.
instance Show i => Show (IxR n i) where
  showsPrec _ (IxR l) = listrShow shows l

-- | The index of the given rank with all components zero.
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n

-- | Convert an @IIxX@ index to a ranked index with the same components.
ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR ZIX = ZIR
ixCvtXR (n :.% idx) = n :.: ixCvtXR idx

-- | Convert a ranked index to an @IIxX@ index with the same components, over
-- a shape consisting of @n@ copies of @Nothing@.
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: (idx :: IxR m Int)) =
  -- The cast rewrites the type-level list along the equality supplied by
  -- lemReplicateSucc.
  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (n :.% ixCvtRX idx)

-- | First component of a non-empty index.
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list

-- | All but the first component of a non-empty index.
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)

-- | Concatenate two indices ('listrAppend' under the newtype).
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)

-- | The rank of the index as a singleton natural.
ixrLengthSNat :: IxR n i -> SNat n
ixrLengthSNat (IxR sh) = listrLengthSNat sh

-- | 'listrPermutePrefix' under the newtype; see there for the error cases.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)


-- | The shape of a rank-typed array: a 'ListR' of dimension sizes.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord)
  deriving newtype (Functor, Foldable)

-- | The empty shape.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

-- | A dimension size followed by the rest of the shape.
pattern (:$:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
  where i :$: (ShR sh) = ShR (i ::: sh)
infixr 3 :$:

{-# COMPLETE ZSR, (:$:) #-}

-- | A 'ShR' with 'Int' components.
type IShR n = ShR n Int

-- | Rendered in list notation (see 'listrShow'); the precedence is ignored.
instance Show i => Show (ShR n i) where
  showsPrec _ (ShR l) = listrShow shows l

-- | Convert an @IShX@ shape over @n@ copies of @Nothing@ to a ranked shape.
-- The type equalities needed for this are asserted with @unsafeCoerceRefl@
-- rather than proved.
shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
shCvtXR' ZSX =
  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
    ZSR
shCvtXR' (n :$% (idx :: IShX sh))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
      castWith (subst2 (lem1 @sh Refl))
        (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
  where
    lem1 :: forall sh' n' k.
            k : sh' :~: Replicate n' Nothing
         -> Rank sh' + 1 :~: n'
    lem1 Refl = unsafeCoerceRefl

    lem2 :: k : sh :~: Replicate n Nothing
         -> sh :~: Replicate (Rank sh) Nothing
    lem2 Refl = unsafeCoerceRefl

-- | Convert a ranked shape to an @IShX@ shape in which every dimension is
-- wrapped in @SUnknown@.
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (n :$: (idx :: ShR m Int)) =
  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (SUnknown n :$% shCvtRX idx)

-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh

-- | First dimension of a non-empty shape.
shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list

-- | All but the first dimension of a non-empty shape.
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)

-- | Concatenate two shapes ('listrAppend' under the newtype).
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)

-- | The rank of the shape as a singleton natural.
shrLengthSNat :: ShR n i -> SNat n
shrLengthSNat (ShR sh) = listrLengthSNat sh

-- | 'listrPermutePrefix' under the newtype; see there for the error cases.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)


-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i
  fromList = go (SNat @n)
    where
      -- Consume the list in lockstep with the expected length.
      go :: SNat n' -> [i] -> ListR n' i
      go SZ [] = ZR
      go (SS n) (i : is) = i ::: go n is
      go _ _ = error "IsList(ListR): Mismatched list length"
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList = IxR . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList = ShR . IsList.fromList
  toList = Foldable.toList


-- | A heterogeneous list indexed by a type-level list of naturals: the
-- element at a position with type-level entry @n@ has type @f n@.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  ZS :: ListS '[] f
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$

-- | Rendered in list notation (see 'listsShow'); the precedence is ignored.
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _ = listsShow shows

-- | Result of 'listsUncons': the tail and the head of a non-empty 'ListS',
-- packaged with the 'KnownNat' dictionary of the head and the evidence
-- relating the two type-level lists.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)

-- | Split a 'ListS' into its head and tail; 'Nothing' for the empty list.
-- Used as the view function behind the @(:.$)@ and @(:$$)@ pattern synonyms.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing

-- | Apply an index-preserving function to every element.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs

-- | Map every element into a monoid and combine the results from left to
-- right.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold _ ZS = mempty
listsFold f (x ::$ xs) = f x <> listsFold f xs

-- | Show a 'ListS' as @[a,b,c]@, rendering each element with the given
-- function.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
  where
    -- The first argument is the separator to emit before the next element.
    go :: String -> ListS sh' f -> ShowS
    go _ ZS = id
    go prefix (x ::$ xs) = showString prefix . f x . go "," xs

-- | The elements of a 'ListS' over 'Const' as an ordinary list.
listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is

-- | First element of a non-empty list.
listsHead :: ListS (n : sh) i -> i n
listsHead (i ::$ _) = i

-- | All but the first element of a non-empty list.
listsTail :: ListS (n : sh) i -> ListS sh i
listsTail (_ ::$ sh) = sh

-- | Concatenate two lists.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'

-- | Keep as many leading elements as the permutation has entries. Calls
-- 'error' if the permutation is longer than the list.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Drop as many leading elements as the permutation has entries. Calls
-- 'error' if the permutation is longer than the list.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm PNil sh = sh
listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Build a list by looking up each entry of the permutation in the given
-- list (via 'listsIndex').
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    -- Matching on SNat brings the KnownNat dictionary required by (::$)
    -- into scope.
    (item, SNat) -> item ::$ listsPermute is sh

-- | The element at (zero-based) position @i@, together with a singleton for
-- its type-level index. Calls 'error' if the position is out of range. The
-- proxy arguments are only passed along to the recursive call.
--
-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (n ::$ _) = (n, SNat)
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"

-- | Permute the leading elements covered by the permutation (see
-- 'listsPermute') and keep the remaining elements unchanged.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)


-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord)

-- | The empty index.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS

-- | An index component followed by the rest of the index.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

{-# COMPLETE ZIS, (:.$) #-}

-- | An 'IxS' with 'Int' components.
type IIxS sh = IxS sh Int

-- | Rendered in list notation (see 'listsShow'); the precedence is ignored.
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l

-- | Maps over the index components.
instance Functor (IxS sh) where
  fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)

-- | Folds over the index components from left to right.
instance Foldable (IxS sh) where
  foldMap f (IxS l) = listsFold (f . getConst) l

-- | The index into the given shape with all components zero.
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh

-- | Convert an @IIxX@ index to a shape-typed index with the same components.
-- The 'ShS' argument is only used to drive the recursion.
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx

-- | Convert a shape-typed index to an @IIxX@ index with the same components.
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh

-- | First component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)

-- | All but the first component of a non-empty index.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)

-- | Concatenate two indices ('listsAppend' under the newtype).
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))

-- | 'listsPermutePrefix' under the newtype.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))


-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Eq, Ord)

-- | The empty shape.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS

-- | A dimension size followed by the rest of the shape.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)

infixr 3 :$$

{-# COMPLETE ZSS, (:$$) #-}

-- | Rendered as a list of the dimension sizes; the precedence is ignored.
instance Show (ShS sh) where
  showsPrec _ (ShS l) = listsShow (shows . fromSNat) l

-- | The number of dimensions in the shape.
shsLength :: ShS sh -> Int
shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)

-- | The dimension sizes as a list of 'Int's.
shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh

-- | Convert an @IShX@ shape over @MapJust sh@ to a shape-typed shape. The
-- type equalities needed for this are asserted with @unsafeCoerceRefl@
-- rather than proved.
shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
  castWith (subst1 (lem Refl)) $
    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
                                   idx)
  where
    lem :: forall sh1 sh' n.
           Just n : sh1 :~: MapJust sh'
        -> n : Tail sh' :~: sh'
    lem Refl = unsafeCoerceRefl
-- Presumably ruled out by the @MapJust@ in the argument type.
shCvtXS' (SUnknown _ :$% _) = error "impossible"

-- | Convert a shape-typed shape to an @IShX@ shape in which every dimension
-- is wrapped in @SKnown@.
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX ZSS = ZSX
shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh

-- | First dimension of a non-empty shape.
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list

-- | All but the first dimension of a non-empty shape.
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)

-- | Concatenate two shapes ('listsAppend' under the newtype).
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)

-- | The product of all dimension sizes as an 'Int'; 1 for the empty shape.
shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh

-- | 'listsTakeLenPerm' under the newtype.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen = coerce (listsTakeLenPerm @SNat)

-- | 'listsPermute' under the newtype.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)

-- | The dimension at (zero-based) position @i@; the first component of
-- 'listsIndex' under the newtype.
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))

-- | 'listsPermutePrefix' under the newtype.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)

-- | The type-level product of a list of naturals; 1 for the empty list.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | The product of all dimension sizes as a singleton; the value-level
-- counterpart of 'Product'.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct ZSS = SNat
shsProduct (n :$$ sh) = n `snatMul` shsProduct sh

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where knownShS :: ShS sh
instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS

-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i
  fromList input = build (knownShS @sh) input
    where
      -- Walk the statically known shape and the list in lockstep, wrapping
      -- every list element in 'Const'.
      build :: ShS sh' -> [i] -> ListS sh' (Const i)
      build ZSS [] = ZS
      build (_ :$$ rest) (x : xs) = Const x ::$ build rest xs
      build _ _ = error lengthMismatch

      -- Only evaluated when the list and the shape disagree on length.
      lengthMismatch =
        "IsList(ListS): Mismatched list length (type says "
          ++ show (shsLength (knownShS @sh))
          ++ ", list has length "
          ++ show (length input)
          ++ ")"
  toList = listsToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  fromList = IxS . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int
  fromList input = ShS (build (knownShS @sh) input)
    where
      -- Walk the statically known shape and the list in lockstep; every
      -- list element has to equal the dimension size given by the type.
      build :: ShS sh' -> [Int] -> ListS sh' SNat
      build ZSS [] = ZS
      build (dim :$$ rest) (x : xs)
        | x /= fromSNat' dim =
            error $ "IsList(ShS): Value does not match typing (type says "
                      ++ show (fromSNat' dim) ++ ", list contains " ++ show x ++ ")"
        | otherwise = dim ::$ build rest xs
      build _ _ = error lengthMismatch

      -- Only evaluated when the list and the shape disagree on length.
      lengthMismatch =
        "IsList(ShS): Mismatched list length (type says "
          ++ show (shsLength (knownShS @sh))
          ++ ", list has length "
          ++ show (length input)
          ++ ")"
  toList = shsToList