diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 399 |
1 files changed, 399 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs new file mode 100644 index 0000000..60e0252 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -0,0 +1,399 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped.Shape where + +import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) +import Data.Array.Shape qualified as O +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Constraint, Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (build, withDict) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Mixed.ListX +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types + + +-- * Shaped indices + +-- | An index into a shape-typed array. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (IxX (MapJust sh) i) + deriving (Eq, Ord, NFData, Functor, Foldable) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS <- IxS (matchZIX -> Just Refl) + where ZIS = IxS ZIX + +matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZIX _ = Nothing + +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l)) + where i :.$ IxS l = IxS (i :.% l) +infixr 3 :.$ + +data UnconsIxSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i) +ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1) +ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1) + , Refl <- lemMapJustCons @sh1 Refl = + Just (UnconsIxSRes i (IxS l)) +ixsUncons (IxS _) = Nothing + +{-# COMPLETE ZIS, (:.$) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxS sh = IxS sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else +instance Show i => Show (IxS sh i) where + showsPrec _ l = ixsShow shows l +#endif + +ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS +ixsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxS sh' i -> ShowS + go _ ZIS = id + go prefix (x :.$ xs) = showString prefix . f x . go "," xs + +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank ZIS = SNat +ixsRank (_ :.$ sh) = snatSucc (ixsRank sh) + +{-# INLINE ixsFromList #-} +ixsFromList :: ShS sh -> [i] -> IxS sh i +ixsFromList sh l = assert (shsLength sh == length l) + $ IxS $ IsList.fromList l + +{-# INLINE ixsFromIxS #-} +ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS sh l = assert (length sh == length l) + $ IxS $ IsList.fromList l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixsHead :: IxS (n : sh) i -> i +ixsHead (i :.$ _) = i + +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (_ :.$ sh) = sh + +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh +ixsInit (_ :.$ ZIS) = ZIS + +ixsLast :: IxS (n : sh) i -> i +ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh +ixsLast (n :.$ ZIS) = n + +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx + +ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i +ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ + coerce (ixxAppend @_ @_ @i) + +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) +ixsZip ZIS ZIS = ZIS +ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js + +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k +ixsZipWith _ ZIS ZIS = ZIS +ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js + +ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i +ixsTakeLenPerm PNil _ = ZIS +ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh +ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i +ixsDropLenPerm PNil sh = sh +ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh +ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i +ixsPermute PNil _ = ZIS +ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) = + case ixsIndex i sh of + item -> item :.$ ixsPermute is sh + +ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j +ixsIndex SZ (n :.$ _) = n +ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh +ixsIndex _ ZIS = error "Index into empty shape" + +ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh) + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS = coerce + +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i + +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = coerce + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' (ShS sh) = (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh + +-- * Shaped shapes + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +-- +-- Note that because the shape of a shape-typed array is known statically, you +-- can also retrieve the array shape from a 'KnownShS' dictionary. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing + +pattern (:$$) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh)) + where i :$$ ShS sh = ShS (SKnown i :$% sh) +infixr 3 :$$ + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing + +{-# COMPLETE ZSS, (:$$) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else +instance Show (ShS sh) where + showsPrec d (ShS shx) = showsPrec d shx +#endif + +instance TestEquality ShS where + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality + +shsLength :: ShS sh -> Int +shsLength (ShS shx) = shxLength shx + +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) | Refl <- lemRankMapJust (Proxy @sh) = + shxRank shx + +lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust _ = unsafeCoerceRefl + +shsSize :: ShS sh -> Int +shsSize (ShS sh) = shxSize sh + +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. We don't report the size of the list +-- in case of errors in order not to retain the list. +{-# INLINEABLE shsFromList #-} +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0 + where + go :: ShX sh' Int -> [Int] -> () + go ZSX [] = () + go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error "shsFromList: Value does not match typing" + go ConsUnknown{} _ = error "shsFromList: impossible case" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")" + +-- This is equivalent to but faster than @coerce shxToList@. +{-# INLINEABLE shsToList #-} +shsToList :: ShS sh -> [Int] +shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go ConsUnknown{} = error "shsToList: impossible case" + go (ConsKnown snat rest) = fromSNat' snat `cons` go rest + in go l) + +shsHead :: ShS (n : sh) -> SNat n +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat + +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) + +{-# INLINEABLE shsTakeIx #-} +shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh +shsTakeIx _ ZIS _ = ZSS +shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh' + +{-# INLINEABLE shsDropIx #-} +shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh' +shsDropIx ZIS long = long +shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long' + +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @Int) + +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat + +shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @Int) + +shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLenPerm = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce (shxTakeLenPerm @Int) + +shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLenPerm = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce (shxDropLenPerm @Int) + +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce (shxPermute @Int) + +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i (ShS sh) = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex @Int i sh of + SKnown SNat -> SNat + +shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) + +type family Product sh where + Product '[] = 1 + Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + +withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r +withKnownShS = withDict @(KnownShS sh) + +shsKnownShS :: ShS sh -> Dict KnownShS sh +shsKnownShS ZSS = Dict +shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict + +shsOrthotopeShape :: ShS sh -> Dict O.Shape sh +shsOrthotopeShape ZSS = Dict +shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict + + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = ixsFromList (knownShS @sh) + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList = shsFromList (knownShS @sh) + toList = shsToList |
