{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Shaped.Shape where

import Control.DeepSeq (NFData(..))
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types


-- * Shaped lists

-- | A list indexed by a type-level list of naturals: a value of type
-- @ListS '[n1, ..., nk] f@ stores one @f ni@ for every entry of the index.
--
-- Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
-- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  ZS :: ListS '[] f
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance (forall n. Show (f n)) => Show (ListS sh f)
#else
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _ = listsShow shows
#endif

instance (forall m. NFData (f m)) => NFData (ListS n f) where
  rnf ZS = ()
  rnf (hd ::$ tl) = rnf hd `seq` rnf tl

-- | The result of successfully splitting a 'ListS' into its tail and its head
-- element, packaged together with the evidence that the list's index is a
-- cons cell.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)

-- | Split off the first element, or return 'Nothing' for the empty list.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons ZS = Nothing
listsUncons (hd ::$ tl) = Just (UnconsListSRes tl hd)

-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
-- 'testEquality', except on the penultimate type parameter.
listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqType ZS ZS = Just Refl
listsEqType (x ::$ xs) (y ::$ ys)
  | Just Refl <- testEquality x y
  , Just Refl <- listsEqType xs ys
  = Just Refl
listsEqType _ _ = Nothing

-- | This checks whether the two lists actually contain equal values. This is
-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
-- in the @some@ package (except on the penultimate type parameter).
listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqual ZS ZS = Just Refl
listsEqual (x ::$ xs) (y ::$ ys)
  | Just Refl <- testEquality x y
  , x == y
  , Just Refl <- listsEqual xs ys
  = Just Refl
listsEqual _ _ = Nothing

-- | Apply an index-preserving function to every element.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
listsFmap fn (hd ::$ tl) = fn hd ::$ listsFmap fn tl

-- | Map every element into a monoid and combine the results from left to
-- right.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold _ ZS = mempty
listsFold fn (hd ::$ tl) = fn hd <> listsFold fn tl

-- | Render the list as @[a,b,c]@, using the given function to render each
-- element.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow showElt list = showString "[" . items "" list . showString "]"
  where
    -- The separator is empty before the first element and "," afterwards.
    items :: String -> ListS sh' f -> ShowS
    items _ ZS = id
    items sep (hd ::$ tl) = showString sep . showElt hd . items "," tl

-- | The number of elements in the list.
listsLength :: ListS sh f -> Int
listsLength = getSum . listsFold (\_ -> Sum 1)

-- | The number of elements in the list, as a type-level natural singleton.
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ tl) = snatSucc (listsRank tl)

-- | Forget the type-level index and return the elements as an ordinary list.
listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const x ::$ xs) = x : listsToList xs

-- | The first element of a non-empty list.
listsHead :: ListS (n : sh) f -> f n
listsHead (hd ::$ _) = hd

-- | Everything but the first element of a non-empty list.
listsTail :: ListS (n : sh) f -> ListS sh f
listsTail (_ ::$ tl) = tl

-- | Everything but the last element of a non-empty list.
listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
listsInit (_ ::$ ZS) = ZS
listsInit (hd ::$ tl@(_ ::$ _)) = hd ::$ listsInit tl

-- | The last element of a non-empty list.
listsLast :: ListS (n : sh) f -> f (Last (n : sh))
listsLast (hd ::$ ZS) = hd
listsLast (_ ::$ tl@(_ ::$ _)) = listsLast tl

-- | Concatenate two lists.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS back = back
listsAppend (hd ::$ front) back = hd ::$ listsAppend front back

-- | Pair up the elements of two lists that have the same index.
listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
listsZip ZS ZS = ZS
listsZip (x ::$ xs) (y ::$ ys) = Fun.Pair x y ::$ listsZip xs ys

-- | Combine the elements of two lists that have the same index, position by
-- position.
listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
             -> ListS sh h
listsZipWith _ ZS ZS = ZS
listsZipWith fn (x ::$ xs) (y ::$ ys) = fn x y ::$ listsZipWith fn xs ys

-- | Keep as many leading elements as the permutation has entries. Calls
-- 'error' if the permutation is longer than the list.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` perm) (hd ::$ tl) = hd ::$ listsTakeLenPerm perm tl
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Drop as many leading elements as the permutation has entries. Calls
-- 'error' if the permutation is longer than the list.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm PNil list = list
listsDropLenPerm (_ `PCons` perm) (_ ::$ tl) = listsDropLenPerm perm tl
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"

-- | Build a new list by looking up, for each entry of the permutation in
-- order, the element of the input list at that position.
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    (item, SNat) -> item ::$ listsPermute is sh

-- | Look up the element at the given position; the accompanying 'SNat' is the
-- singleton for the type-level natural found at that position. The two
-- 'Proxy' arguments are never inspected. Calls 'error' when the list runs out
-- before the position is reached.
-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (hd ::$ _) = (hd, SNat)
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"

-- | Permute the leading part of the list (as many elements as the permutation
-- has entries) and append the remaining elements unchanged.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm list = listsAppend permuted rest
  where
    permuted = listsPermute perm (listsTakeLenPerm perm list)
    rest = listsDropLenPerm perm list


-- * Shaped indices

-- | An index into a shape-typed array.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord, Generic)

-- | The empty index.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS

-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
-- removed in a future release.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

{-# COMPLETE ZIS, (:.$) #-}

-- | For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS list) = listsShow (shows . getConst) list
#endif

instance Functor (IxS sh) where
  fmap fn (IxS list) = IxS (listsFmap (\(Const x) -> Const (fn x)) list)

instance Foldable (IxS sh) where
  foldMap fn (IxS list) = listsFold (\(Const x) -> fn x) list

instance NFData i => NFData (IxS sh i)

-- | The number of components of the index.
ixsLength :: IxS sh i -> Int
ixsLength (IxS list) = listsLength list

-- | The number of components of the index, as a type-level natural singleton.
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank (IxS list) = listsRank list

-- | The index into the given shape whose components are all zero.
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ rest) = 0 :.$ ixsZero rest

-- | The first component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS items) = getConst (listsHead items)

-- | Everything but the first component of a non-empty index.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS items) = IxS (listsTail items)

-- | Everything but the last component of a non-empty index.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (IxS items) = IxS (listsInit items)

-- | The last component of a non-empty index.
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS items) = getConst (listsLast items)

-- | Re-type an index against another shape, keeping its components. Calls
-- 'error' if the index and the target shape have different lengths.
-- TODO: this takes a ShS because there are KnownNats inside IxS.
ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
ixsCast ZSS ZIS = ZIS
ixsCast (_ :$$ shRest) (x :.$ ixRest) = x :.$ ixsCast shRest ixRest
ixsCast _ _ = error "ixsCast: ranks don't match"

-- | Concatenate two indices; this is 'listsAppend' under the newtype.
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))

-- | Pair up the components of two indices of the same shape.
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (x :.$ xs) (y :.$ ys) = (x, y) :.$ ixsZip xs ys

-- | Combine the components of two indices of the same shape, position by
-- position.
ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith fn (x :.$ xs) (y :.$ ys) = fn x y :.$ ixsZipWith fn xs ys

-- | 'listsPermutePrefix' under the newtype.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))


-- * Shaped shapes

-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Generic)

-- Two values of the same 'ShS' type are always considered equal.
instance Eq (ShS sh) where _ == _ = True
instance Ord (ShS sh) where compare _ _ = EQ

-- | The empty shape.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS

-- | A shape consisting of a first dimension followed by the remaining
-- dimensions.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)
infixr 3 :$$

{-# COMPLETE ZSS, (:$$) #-}

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
  showsPrec _ (ShS list) = listsShow (shows . fromSNat) list
#endif

instance NFData (ShS sh) where
  rnf (ShS ZS) = ()
  rnf (ShS (SNat ::$ rest)) = rnf (ShS rest)

instance TestEquality ShS where
  testEquality (ShS lhs) (ShS rhs) = listsEqType lhs rhs

-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality

-- | The number of dimensions of the shape.
shsLength :: ShS sh -> Int
shsLength (ShS list) = listsLength list

-- | The number of dimensions of the shape, as a type-level natural singleton.
shsRank :: ShS sh -> SNat (Rank sh)
shsRank (ShS list) = listsRank list

-- | The product of all dimensions; 1 for the empty shape.
shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (dim :$$ rest) = fromSNat' dim * shsSize rest

-- | The dimensions of the shape as ordinary 'Int's.
shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (dim :$$ rest) = fromSNat' dim : shsToList rest

-- | The first dimension of a non-empty shape.
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS dims) = listsHead dims

-- | Everything but the first dimension of a non-empty shape.
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS dims) = ShS (listsTail dims)

-- | Everything but the last dimension of a non-empty shape.
shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
shsInit (ShS dims) = ShS (listsInit dims)

-- | The last dimension of a non-empty shape.
shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
shsLast (ShS dims) = listsLast dims

-- | Concatenate two shapes; this is 'listsAppend' under the newtype.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)

-- | 'listsTakeLenPerm' under the newtype.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen = coerce (listsTakeLenPerm @SNat)

-- | 'listsPermute' under the newtype.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)

-- | The dimension at the given position: the first component of the result of
-- 'listsIndex'.
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))

-- | 'listsPermutePrefix' under the newtype.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)

-- | The type-level product of a list of naturals; 1 for the empty list.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | The product of all dimensions, as a type-level natural singleton.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct ZSS = SNat
shsProduct (dim :$$ rest) = dim `snatMul` shsProduct rest

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where knownShS :: ShS sh

instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS

-- | Run a computation that needs a 'KnownShS' instance, using the given shape
-- value as that instance.
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS = withDict @(KnownShS sh)

-- | Produce a 'KnownShS' dictionary from a shape value.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
shsKnownShS (SNat :$$ rest) | Dict <- shsKnownShS rest = Dict

-- | Produce an @O.Shape@ dictionary from a shape value.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ rest) | Dict <- shsOrthotopeShape rest = Dict

-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
-- This function may be removed in a future release.
shsFromListS :: ListS sh f -> ShS sh
shsFromListS ZS = ZSS
shsFromListS (_ ::$ rest) = SNat :$$ shsFromListS rest

-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
-- function may be removed in a future release.
shsFromIxS :: IxS sh i -> ShS sh
shsFromIxS (IxS list) = shsFromListS list

-- | All 'Int' indices into the shape, in the order produced by 'shsEnum''.
shsEnum :: ShS sh -> [IIxS sh]
shsEnum = shsEnum'

-- | One index for every linear position from 0 up to @'shsSize' sh - 1@, in
-- increasing order of linear position. Each linear position is decomposed
-- into components by repeated division by the sizes of the trailing
-- sub-shapes.
{-# INLINABLE shsEnum' #-}  -- ensure this can be specialised at use site
shsEnum' :: Num i => ShS sh -> [IxS sh i]
shsEnum' shape = [fromLin shape strides lin# | I# lin# <- [0 .. shsSize shape - 1]]
  where
    -- For dimensions [a, b, c] this is [b * c, c, 1].
    strides = drop 1 (scanr (*) 1 (shsToList shape))

    fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
    fromLin ZSS _ _ = ZIS
    fromLin (_ :$$ rest) (I# stride# : strides') pos# =
      let !(# q#, r# #) = pos# `quotRemInt#` stride#  -- stride == shsSize rest
      in fromIntegral (I# q#) :.$ fromLin rest strides' r#
    fromLin _ _ _ = error "impossible"


-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i
  fromList topl = go (knownShS @sh) topl
    where
      go :: ShS sh' -> [i] -> ListS sh' (Const i)
      go ZSS [] = ZS
      go (_ :$$ rest) (x : xs) = Const x ::$ go rest xs
      go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
                         ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                         ++ show (length topl) ++ ")"
  toList = listsToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  fromList = IxS . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int
  fromList topl = ShS (go (knownShS @sh) topl)
    where
      go :: ShS sh' -> [Int] -> ListS sh' SNat
      go ZSS [] = ZS
      go (dim :$$ rest) (x : xs)
        | x == fromSNat' dim = dim ::$ go rest xs
        | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
                                ++ show (fromSNat' dim) ++ ", list contains " ++ show x ++ ")"
      go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
                         ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                         ++ show (length topl) ++ ")"
  toList = shsToList