{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Shapes ('ShS'), indices ('IxS') and the underlying type-indexed list
-- ('ListS') for shape-typed arrays, whose shape is a type-level list of
-- naturals.
module Data.Array.Nested.Shaped.Shape where

import Control.DeepSeq (NFData(..))
import Data.Array.Mixed.Types
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Nested.Mixed.Shape


-- | A list indexed by a type-level list of naturals @sh@: the element at
-- position /k/ has type @f n@, where @n@ is the /k/-th entry of @sh@. Every
-- cons cell additionally stores a 'KnownNat' dictionary for its @n@.
--
-- 'IxS' (with @f = 'Const' i@) and 'ShS' (with @f = 'SNat'@) are newtypes
-- over this type.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  -- | The empty list.
  ZS :: ListS '[] f
  -- Prepend an element; the cell also captures the 'KnownNat' for its index.
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$

-- | Renders the elements comma-separated between square brackets (see
-- 'listsShow'); the precedence argument is ignored.
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _ = listsShow shows

-- | Forces every element, and with it the whole spine, to normal form.
instance (forall m. NFData (f m)) => NFData (ListS n f) where
  rnf ZS = ()
  rnf (x ::$ l) = rnf x `seq` rnf l

-- | Result of 'listsUncons': the tail and the head of a non-empty list,
-- together with the evidence that @sh1@ is @n : sh@ and that @n@ is known.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)

-- | Split off the first element of a list, or return 'Nothing' for the empty
-- list. The ':.$' and ':$$' pattern synonyms match through this function.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing

-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
-- 'testEquality', except on the penultimate type parameter.
listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqType ZS ZS = Just Refl
listsEqType (n ::$ sh) (m ::$ sh')
  | Just Refl <- testEquality n m
  , Just Refl <- listsEqType sh sh'
  = Just Refl
listsEqType _ _ = Nothing

-- | This checks whether the two lists actually contain equal values. This is
-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
-- in the @some@ package (except on the penultimate type parameter).
listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqual ZS ZS = Just Refl
listsEqual (n ::$ sh) (m ::$ sh')
  | Just Refl <- testEquality n m
  , n == m
  , Just Refl <- listsEqual sh sh'
  = Just Refl
listsEqual _ _ = Nothing

-- | Apply a function that works at every index to each element.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs

-- | Map every element into a monoid and combine the results in list order.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold _ ZS = mempty
listsFold f (x ::$ xs) = f x <> listsFold f xs

-- | Render a list as @[x1,x2,...]@ using the given element printer: square
-- brackets, elements separated by a comma without spaces.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
  where
    -- The first element gets an empty prefix; every later one gets ",".
    go :: String -> ListS sh' f -> ShowS
    go _ ZS = id
    go prefix (x ::$ xs) = showString prefix . f x . go "," xs

-- | Number of elements as a plain 'Int' (computed by traversing the list).
listsLength :: ListS sh f -> Int
listsLength = getSum . listsFold (\_ -> Sum 1)

-- | Number of elements, as a singleton for the type-level @'Rank' sh@.
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)

-- | Unwrap a list of 'Const' elements into an ordinary Haskell list.
listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is

-- | The first element. Total: the type guarantees a non-empty list.
listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i

-- | All elements except the first.
listsTail :: ListS (n : sh) f -> ListS sh f
listsTail (_ ::$ sh) = sh

-- | All elements except the last.
listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
listsInit (_ ::$ ZS) = ZS

-- | The last element.
listsLast :: ListS (n : sh) f -> f (Last (n : sh))
listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
listsLast (n ::$ ZS) = n

-- | Concatenate two lists; the type-level shapes are concatenated with '++'.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'

-- | Pair up corresponding elements of two lists with the same shape.
listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
listsZip ZS ZS = ZS
listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js

-- | Combine corresponding elements of two lists with the same shape.
listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h
listsZipWith _ ZS ZS = ZS
listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js

-- | Keep the first @k@ elements, where @k@ is the length of the permutation.
-- Calls 'error' if the list has fewer than @k@ elements.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Drop the first @k@ elements, where @k@ is the length of the permutation.
-- Calls 'error' if the list has fewer than @k@ elements.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm PNil sh = sh
listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Reorder a list by a permutation: the /j/-th output element is the input
-- element at the position given by the /j/-th entry of the permutation.
-- Every entry must be a valid position in the input list; otherwise
-- 'listsIndex' calls 'error'.
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  -- Matching on the returned SNat brings into scope the KnownNat that (::$)
  -- requires for the selected element.
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    (item, SNat) -> item ::$ listsPermute is sh

-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
-- | The element at 0-based position @i@, together with an 'SNat' for its
-- type index. Calls 'error' if @i@ is past the end of the list.
--
-- The two 'Proxy' arguments are only passed along the recursion; they do not
-- influence the result.
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (n ::$ _) = (n, SNat)
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  -- lemIndexSucc supplies the type equality that lets the recursive result be
  -- returned at the outer type (presumably Index (i'+1) (n : sh') ~ Index i' sh').
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"

-- | Permute the first @k@ elements (where @k@ is the length of the
-- permutation) and keep the remaining elements, unchanged, after them.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)


-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord, Generic)

-- | The empty index.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS

-- | Prepend a component to an index; as a pattern, split off the first one.
pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

{-# COMPLETE ZIS, (:.$) #-}

-- | An index whose components are 'Int's.
type IIxS sh = IxS sh Int

-- | Shown like a list of its components, e.g. @[1,2,3]@.
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l

-- | Maps over every component.
instance Functor (IxS sh) where
  fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)

-- | Folds over the components in order.
instance Foldable (IxS sh) where
  foldMap f (IxS l) = listsFold (f . getConst) l

-- | No methods given: relies on the 'Generic'-based default of 'rnf'.
instance NFData i => NFData (IxS sh i)

-- | Number of components as a plain 'Int'.
ixsLength :: IxS sh i -> Int
ixsLength (IxS l) = listsLength l

-- | Number of components as a type-level singleton.
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank (IxS l) = listsRank l

-- | The all-zero index for the given shape.
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh

-- | Convert a mixed index over a fully known shape (@'MapJust' sh@) to a
-- shaped index. The 'ShS' argument only drives the recursion; its dimension
-- values are not inspected.
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx

-- | Convert a shaped index to a mixed index over @'MapJust' sh@.
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh

-- | The first component.
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)

-- | All components except the first.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)

-- | All components except the last.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (IxS list) = IxS (listsInit list)

-- | The last component.
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)

-- | Concatenate two indices ('listsAppend' through a zero-cost 'coerce').
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))

-- | Pair up corresponding components of two indices of the same shape.
ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js

-- | Combine corresponding components of two indices of the same shape.
ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js

-- | 'listsPermutePrefix' for indices.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))


-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Eq, Ord, Generic)

-- | The empty (rank-0) shape.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS

-- | Prepend a dimension to a shape; as a pattern, split off the first one.
pattern (:$$) :: forall {sh1}. forall n sh. (KnownNat n, n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)

infixr 3 :$$

{-# COMPLETE ZSS, (:$$) #-}

-- | Shown as a list of the dimension sizes, e.g. @[2,3]@.
instance Show (ShS sh) where
  showsPrec _ (ShS l) = listsShow (shows . fromSNat) l

-- | Forces every 'SNat' in the shape.
instance NFData (ShS sh) where
  rnf (ShS ZS) = ()
  rnf (ShS (SNat ::$ l)) = rnf (ShS l)

-- | Compares the type-level shapes via 'listsEqType'.
instance TestEquality ShS where
  testEquality (ShS l1) (ShS l2) = listsEqType l1 l2

-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality

-- | Number of dimensions as a plain 'Int'.
shsLength :: ShS sh -> Int
shsLength (ShS l) = listsLength l

-- | Number of dimensions as a type-level singleton.
shsRank :: ShS sh -> SNat (Rank sh)
shsRank (ShS l) = listsRank l

-- | Total number of elements of an array of this shape: the product of the
-- dimensions, 1 for the empty shape. Computed with 'Int' multiplication, so
-- there is no overflow check.
shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh

-- | The dimension sizes as a list of 'Int's.
shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh

-- | Convert a mixed shape whose dimensions are all statically known
-- (@'MapJust' sh@) to a shaped shape.
--
-- The type equalities relating @sh@ to the mixed shape are asserted with
-- 'unsafeCoerceRefl' rather than proven (presumably because GHC cannot invert
-- 'MapJust' by itself); they rely on the argument really having type
-- @IShX (MapJust sh)@.
shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
  castWith (subst1 (lem Refl)) $
    n :$$ shCvtXS' @(Tail sh)
            (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
                      idx)
  where
    lem :: forall sh1 sh' n. Just n : sh1 :~: MapJust sh' -> n : Tail sh' :~: sh'
    lem Refl = unsafeCoerceRefl
-- Not expected to be reachable for a well-typed argument (MapJust presumably
-- yields only known dimensions), hence the "impossible" error.
shCvtXS' (SUnknown _ :$% _) = error "impossible"

-- | Convert a shaped shape to a mixed shape where every dimension is 'SKnown'.
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX ZSS = ZSX
shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh

-- | The first dimension.
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list

-- | All dimensions except the first.
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)

-- | All dimensions except the last.
shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
shsInit (ShS list) = ShS (listsInit list)

-- | The last dimension.
shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
shsLast (ShS list) = listsLast list

-- | Concatenate two shapes.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)

-- | 'listsTakeLenPerm' for shapes.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen = coerce (listsTakeLenPerm @SNat)

-- | 'listsPermute' for shapes.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)

-- | 'listsIndex' for shapes; the selected 'SNat' is itself the result.
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))

-- | 'listsPermutePrefix' for shapes.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)

-- | Type-level product of a list of naturals; 1 for the empty list.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | The product of the dimensions, as a type-level singleton.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct ZSS = SNat
shsProduct (n :$$ sh) = n `snatMul` shsProduct sh

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where
  knownShS :: ShS sh
instance KnownShS '[] where
  knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where
  knownShS = natSing :$$ knownShS

-- | Bring a 'KnownShS' dictionary for the given shape into scope. Uses
-- 'withDict', which turns the 'ShS' value directly into the dictionary of the
-- single-method class 'KnownShS'.
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS k = withDict @(KnownShS sh) k

-- | Rebuild a 'KnownShS' dictionary by recursion over the shape; matching on
-- 'SNat' supplies the 'KnownNat' for each dimension.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict

-- | Evidence that the shape satisfies the 'O.Shape' class from
-- "Data.Array.Shape", built by recursion over the shape.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict


-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i

  -- Walk the statically known shape and the input list in lockstep; running
  -- out of one before the other is a length mismatch.
  fromList elems = build (knownShS @sh) elems
    where
      build :: ShS sh' -> [i] -> ListS sh' (Const i)
      build ZSS [] = ZS
      build (_ :$$ rest) (x : xs) = Const x ::$ build rest xs
      build _ _ =
        error $ "IsList(ListS): Mismatched list length (type says "
                  ++ show (shsLength (knownShS @sh))
                  ++ ", list has length "
                  ++ show (length elems)
                  ++ ")"

  toList = listsToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  -- Delegate to the 'ListS' instance (length check only) and wrap the result.
  fromList xs = IxS (IsList.fromList xs)
  -- The components in order, via the 'Foldable' instance.
  toList ix = Foldable.toList ix

-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int

  -- Besides the length check, every list entry must equal the corresponding
  -- statically known dimension.
  fromList dims = ShS (build (knownShS @sh) dims)
    where
      build :: ShS sh' -> [Int] -> ListS sh' SNat
      build ZSS [] = ZS
      build (sn :$$ rest) (d : ds)
        | d /= expected =
            error $ "IsList(ShS): Value does not match typing (type says "
                      ++ show expected ++ ", list contains " ++ show d ++ ")"
        | otherwise = sn ::$ build rest ds
        where
          expected = fromSNat' sn
      build _ _ =
        error $ "IsList(ShS): Mismatched list length (type says "
                  ++ show (shsLength (knownShS @sh))
                  ++ ", list has length "
                  ++ show (length dims)
                  ++ ")"

  toList = shsToList