{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Shaped.Shape where

import Control.DeepSeq (NFData(..))
import Control.Exception (assert)
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Constraint, Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types


-- * Shaped indices

-- | An index into a shape-typed array. It is a newtype around the mixed
-- index 'IxX' at shape @MapJust sh@.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (IxX (MapJust sh) i)
  deriving (Eq, Ord, NFData, Functor, Foldable)

-- The empty index. Matching on it reveals that the shape is the empty list.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS <- IxS (matchZIX -> Just Refl)
  where ZIS = IxS ZIX

-- | View function behind 'ZIS': succeeds exactly when the underlying mixed
-- index is 'ZIX', and then returns the proof that @sh@ is empty.
matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[])
matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
matchZIX _ = Nothing

-- A non-empty index: a head component followed by the index for the
-- remaining dimensions.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern hd :.$ tl <- (ixsUncons -> Just (UnconsIxSRes hd tl))
  where hd :.$ IxS tl = IxS (hd :.% tl)
infixr 3 :.$

-- | Result of a successful 'ixsUncons'; it packages the evidence that the
-- shape is a cons together with the head component and the tail index.
data UnconsIxSRes i sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i)

-- | View function behind '(:.$)': splits off the first component when the
-- underlying mixed index is a cons, and returns 'Nothing' otherwise.
ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
ixsUncons (IxS (hd :.% tl))
  | Refl <- lemMapJustHead (Proxy @sh1)
  , Refl <- lemMapJustCons @sh1 Refl
  = Just (UnconsIxSRes hd (IxS tl))
ixsUncons (IxS _) = Nothing

{-# COMPLETE ZIS, (:.$) #-}

-- | An 'IxS' with plain 'Int' components. For convenience this uses regular
-- 'Int's instead of bounded integers (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
  showsPrec _ ix = ixsShow shows ix
#endif

-- | Render an index as a bracketed, comma-separated list, using the given
-- function to render each component.
ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS
ixsShow showElem ix = showString "[" . render "" ix . showString "]"
  where
    -- The separator is empty before the first component and "," afterwards.
    render :: String -> IxS sh' i -> ShowS
    render _ ZIS = id
    render sep (e :.$ es) = showString sep . showElem e . render "," es

-- | The number of components of the index, as a type-level-tracked 'SNat'.
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank ZIS = SNat
ixsRank (_ :.$ rest) = snatSucc (ixsRank rest)

-- | Build an index from a list of components. The list length is compared
-- with the length of the given shape via 'assert'.
{-# INLINE ixsFromList #-}
ixsFromList :: ShS sh -> [i] -> IxS sh i
ixsFromList shape elems =
  assert (shsLength shape == length elems) (IxS (IsList.fromList elems))

-- | Like 'ixsFromList', but the expected length is taken from an existing
-- index of the same shape instead of from a 'ShS'.
{-# INLINE ixsFromIxS #-}
ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i
ixsFromIxS template elems =
  assert (length template == length elems) (IxS (IsList.fromList elems))

-- | The index of the given shape whose components are all zero.
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ rest) = 0 :.$ ixsZero rest

-- | First component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead (hd :.$ _) = hd

-- | Everything but the first component of a non-empty index.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (_ :.$ tl) = tl

-- | Everything but the last component of a non-empty index.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (hd :.$ tl@(_ :.$ _)) = hd :.$ ixsInit tl
ixsInit (_ :.$ ZIS) = ZIS

-- | Last component of a non-empty index.
ixsLast :: IxS (n : sh) i -> i
ixsLast (_ :.$ tl@(_ :.$ _)) = ixsLast tl
ixsLast (hd :.$ ZIS) = hd

-- | Rebuild the index component by component. With this signature the
-- result has the same shape and the same components as the argument.
ixsCast :: IxS sh i -> IxS sh i
ixsCast ZIS = ZIS
ixsCast (hd :.$ tl) = hd :.$ ixsCast tl

-- | Concatenate two indices. Implemented by coercing 'listxAppend'; the
-- required equation between the two @MapJust@ forms is supplied by
-- 'unsafeCoerceRefl' rather than proven.
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend =
  gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $
    coerce (listxAppend @_ @_ @i)

-- | Pair up the components of two indices of the same shape.
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (a :.$ as) (b :.$ bs) = (a, b) :.$ ixsZip as bs

-- | Combine the components of two indices of the same shape pointwise.
{-# INLINE ixsZipWith #-}
ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith combine (a :.$ as) (b :.$ bs) =
  combine a b :.$ ixsZipWith combine as bs

-- | Keep as many leading components as the permutation has entries.
-- Calls 'error' when the index runs out before the permutation does.
ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i
ixsTakeLenPerm PNil _ = ZIS
ixsTakeLenPerm (_ `PCons` restPerm) (hd :.$ tl) = hd :.$ ixsTakeLenPerm restPerm tl
ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"

-- | Drop as many leading components as the permutation has entries.
-- Calls 'error' when the index runs out before the permutation does.
ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i
ixsDropLenPerm PNil ix = ix
ixsDropLenPerm (_ `PCons` restPerm) (_ :.$ tl) = ixsDropLenPerm restPerm tl
ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"

-- | For every entry of the permutation, in order, pick the component of the
-- index at that position (via 'ixsIndex').
ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i
ixsPermute PNil _ = ZIS
ixsPermute (pos `PCons` (restPerm :: Perm is')) (ix :: IxS sh f) =
  let picked = ixsIndex pos ix
  in picked :.$ ixsPermute restPerm ix

-- | The component at the given position, counting from zero. Calls 'error'
-- when the index has too few components.
ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j
ixsIndex SZ (hd :.$ _) = hd
ixsIndex (SS pos) (_ :.$ tl) = ixsIndex pos tl
ixsIndex _ ZIS = error "Index into empty shape"

-- | Permute the leading components of the index (as many as the permutation
-- has entries) and append the remaining components unchanged.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix perm ix = ixsAppend permutedPrefix untouchedSuffix
  where
    permutedPrefix = ixsPermute perm (ixsTakeLenPerm perm ix)
    untouchedSuffix = ixsDropLenPerm perm ix

-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
{-# INLINEABLE ixsToLinear #-}
ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
ixsToLinear (ShS shx) idx = ixxToLinear shx (ixxFromIxS idx)

-- | View a shaped index as the underlying mixed index (a pure coercion).
ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
ixxFromIxS = coerce

-- | Turn a linear position into a multidimensional index for the given
-- shape, by delegating to 'ixxFromLinear' on the underlying mixed shape.
{-# INLINEABLE ixsFromLinear #-}
ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
ixsFromLinear (ShS shx) lin = ixsFromIxX (ixxFromLinear shx lin)

-- | View a mixed index at shape @MapJust sh@ as a shaped index (a pure
-- coercion).
ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
ixsFromIxX = coerce

-- | All indices of the given shape, with 'Int' components; this is
-- 'shsEnum'' at component type 'Int'.
shsEnum :: ShS sh -> [IIxS sh]
shsEnum shape = shsEnum' shape

-- | All indices of the given shape for any 'Num' component type, obtained
-- by coercing the result of 'shxEnum''.
{-# INLINABLE shsEnum' #-}  -- ensure this can be specialised at use site
shsEnum' :: Num i => ShS sh -> [IxS sh i]
shsEnum' (ShS shx) =
  (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' shx


-- * Shaped shapes

-- | The shape of a shape-typed array, given as a list of 'SNat' values.
--
-- Because the shape of a shape-typed array is known statically, the same
-- information can also be obtained from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ShX (MapJust sh) Int)
  deriving (NFData)

-- | Two values of the same type @ShS sh@ always compare equal; the
-- arguments are not inspected.
instance Eq (ShS sh) where
  _ == _ = True

-- | Two values of the same type @ShS sh@ always compare as 'EQ'; the
-- arguments are not inspected.
instance Ord (ShS sh) where
  compare _ _ = EQ

-- The empty shape. Matching on it reveals that the shape is the empty list.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS <- ShS (matchZSX -> Just Refl)
  where ZSS = ShS ZSX

-- | View function behind 'ZSS': succeeds exactly when the underlying mixed
-- shape is 'ZSX', and then returns the proof that @sh@ is empty.
matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
matchZSX _ = Nothing

-- A non-empty shape: the size of the outermost dimension followed by the
-- shape of the remaining dimensions.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern dim :$$ rest <- (shsUncons -> Just (UnconsShSRes dim rest))
  where dim :$$ ShS rest = ShS (SKnown dim :$% rest)
infixr 3 :$$

-- | Result of a successful 'shsUncons'; it packages the evidence that the
-- shape is a cons together with the head dimension and the tail shape.
data UnconsShSRes sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)

-- | View function behind '(:$$)': splits off the first dimension when the
-- underlying mixed shape starts with a known size, and returns 'Nothing'
-- otherwise.
shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
shsUncons (ShS (SKnown dim :$% rest))
  | Refl <- lemMapJustCons @sh1 Refl
  = Just (UnconsShSRes dim (ShS rest))
shsUncons (ShS _) = Nothing

{-# COMPLETE ZSS, (:$$) #-}

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
  showsPrec prec (ShS shx) = showsPrec prec shx
#endif

-- | Compares the underlying mixed shapes with 'shxEqType'. On success the
-- equality of the shape lists themselves is produced with
-- 'unsafeCoerceRefl'.
instance TestEquality ShS where
  testEquality (ShS lhs) (ShS rhs) =
    case shxEqType lhs rhs of
      Nothing -> Nothing
      Just Refl -> Just unsafeCoerceRefl

-- | A synonym for 'testEquality' at 'ShS'. (Because 'ShS' is a singleton,
-- types are equal if and only if values are equal.)
shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual lhs rhs = testEquality lhs rhs

-- | The number of dimensions, as a plain 'Int'.
shsLength :: ShS sh -> Int
shsLength (ShS shx) = shxLength shx

-- | The number of dimensions, as an 'SNat' indexed by @Rank sh@.
shsRank :: forall sh. ShS sh -> SNat (Rank sh)
shsRank (ShS shx) =
  case lemRankMapJust (Proxy @sh) of
    Refl -> shxRank shx

-- | @MapJust@ does not change the rank. The equality is asserted with
-- 'unsafeCoerceRefl', not proven.
lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh
lemRankMapJust _ = unsafeCoerceRefl

-- | The size of the shape, computed by 'shxSize' on the underlying mixed
-- shape.
shsSize :: ShS sh -> Int
shsSize (ShS shx) = shxSize shx

-- | A partial @const@: returns the first argument, but only after checking
-- that the list agrees with it, and calls 'error' otherwise. The error
-- messages mention the length demanded by the type but not the length of
-- the list, so that the list is not retained.
{-# INLINEABLE shsFromList #-}
shsFromList :: ShS sh -> [Int] -> ShS sh
shsFromList shape@(ShS shx) dims = check shx dims `seq` shape
  where
    -- Walks the shape and the list in lockstep; the result carries no
    -- information and is only forced for its possible failure.
    check :: ShX sh' Int -> [Int] -> ()
    check ZSX [] = ()
    check ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength shx) ++ ")"
    check (ConsKnown expected more) (d : ds)
      | d == fromSNat' expected = check more ds
      | otherwise = error "shsFromList: Value does not match typing"
    check ConsUnknown{} _ = error "shsFromList: impossible case"
    check _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength shx) ++ ")"

-- This is equivalent to but faster than @coerce shxToList@.
{-# INLINEABLE shsToList #-}
-- | The dimensions of the shape as a list of 'Int's, one per dimension.
-- The list is produced through 'build' (the producer side of foldr/build
-- list fusion).
shsToList :: ShS sh -> [Int]
shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) ->
  let go :: ShX sh Int -> is
      go ZSX = nil
      -- A shape wrapped in 'ShS' has type @ShX (MapJust sh) Int@, hence this
      -- case is reported as impossible.
      go ConsUnknown{} = error "shsToList: impossible case"
      go (ConsKnown snat rest) = fromSNat' snat `cons` go rest
  in go l)

-- | The size of the first dimension of a non-empty shape.
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS shx) = case shxHead shx of SKnown SNat -> SNat

-- | The shape without its first dimension; a coercion of 'shxTail'.
shsTail :: forall n sh. ShS (n : sh) -> ShS sh
shsTail = coerce (shxTail @_ @_ @Int)

-- | The prefix of the shape that has as many dimensions as the given index
-- has components. The 'Proxy' argument is only passed along in the
-- recursion; it serves to fix the type @sh'@ of the remaining dimensions.
{-# INLINEABLE shsTakeIx #-}
shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh
shsTakeIx _ ZIS _ = ZSS
shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh'

-- | The suffix of the shape that remains after dropping as many leading
-- dimensions as the given index has components.
{-# INLINEABLE shsDropIx #-}
shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh'
shsDropIx ZIS long = long
shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long'

-- | The shape without its last dimension; a coercion of 'shxInit'. The
-- equation relating @Init@ and @MapJust@ is supplied by 'unsafeCoerceRefl'.
shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
shsInit =
  gcastWith (unsafeCoerceRefl
               :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
    coerce (shxInit @Int)

-- | The size of the last dimension of a non-empty shape. The equation
-- relating @Last@ and @MapJust@ is supplied by 'unsafeCoerceRefl'.
shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
shsLast (ShS shx) =
  gcastWith (unsafeCoerceRefl
               :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
    case shxLast shx of SKnown SNat -> SNat

-- | Concatenate two shapes; a coercion of 'shxAppend'. The equation relating
-- @++@ and @MapJust@ is supplied by 'unsafeCoerceRefl'.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend =
  gcastWith (unsafeCoerceRefl :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
    coerce (shxAppend @_ @Int)

-- | Shape-level counterpart of @TakeLen@; a coercion of 'shxTakeLenPerm'.
-- The equation relating @TakeLen@ and @MapJust@ is supplied by
-- 'unsafeCoerceRefl'.
shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLenPerm =
  gcastWith (unsafeCoerceRefl :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
    coerce (shxTakeLenPerm @Int)

-- | Shape-level counterpart of @DropLen@; a coercion of 'shxDropLenPerm'.
-- The equation relating @DropLen@ and @MapJust@ is supplied by
-- 'unsafeCoerceRefl'.
shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
shsDropLenPerm =
  gcastWith (unsafeCoerceRefl :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
    coerce (shxDropLenPerm @Int)

-- | Shape-level counterpart of @Permute@; a coercion of 'shxPermute'. The
-- equation relating @Permute@ and @MapJust@ is supplied by
-- 'unsafeCoerceRefl'.
shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute =
  gcastWith (unsafeCoerceRefl :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
    coerce (shxPermute @Int)

-- | The size of the dimension at the given position, via 'shxIndex'. The
-- equation relating @Index@ and @MapJust@ is supplied by 'unsafeCoerceRefl'.
shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
shsIndex i (ShS sh) =
  gcastWith (unsafeCoerceRefl
               :: Index i (MapJust sh) :~: Just (Index i sh)) $
    case shxIndex @Int i sh of SKnown SNat -> SNat

-- | Shape-level counterpart of @PermutePrefix@, via 'shxPermutePrefix'. The
-- type equation is currently supplied by 'unsafeCoerceRefl'; the TODO below
-- lists the lemmas meant to replace it.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix perm (ShS shx)
  {- TODO: here and elsewhere, solve the module dependency cycle and add this:
  | Refl <- lemTakeLenMapJust perm sh
  , Refl <- lemDropLenMapJust perm sh
  , Refl <- lemPermuteMapJust perm sh
  , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -}
  = gcastWith (unsafeCoerceRefl
                 :: Permute is (TakeLen is (MapJust sh)) ++ DropLen is (MapJust sh)
                    :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
      ShS (shxPermutePrefix perm shx)

-- | Type-level product of the dimensions of a shape; the product of the
-- empty shape is 1.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | Value-level witness of 'Product': multiplies the dimension sizes with
-- 'snatMul'.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct ZSS = SNat
shsProduct (n :$$ sh) = n `snatMul` shsProduct sh

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where knownShS :: ShS sh
instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS

-- | Run a computation that needs a 'KnownShS' instance, using the given
-- 'ShS' value as that instance (via 'withDict').
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS = withDict @(KnownShS sh)

-- | Build the 'KnownShS' dictionary for a shape by recursion over its
-- dimensions.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict

-- | Build the @O.Shape@ dictionary (class from "Data.Array.Shape") for a
-- shape by recursion over its dimensions.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  -- The expected number of components comes from the statically known
  -- shape; 'ixsFromList' compares it with the length of the list.
  fromList elems = ixsFromList (knownShS @sh) elems
  -- Components are listed in 'Foldable' order.
  toList ix = Foldable.toList ix

-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int
  -- 'shsFromList' checks the list against the statically known shape and
  -- returns that shape.
  fromList dims = shsFromList (knownShS @sh) dims
  toList shape = shsToList shape