aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs399
1 files changed, 399 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
new file mode 100644
index 0000000..60e0252
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -0,0 +1,399 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Shaped.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Control.Exception (assert)
+import Data.Array.Shape qualified as O
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Kind (Constraint, Type)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Exts (build, withDict)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+
+import Data.Array.Nested.Mixed.ListX
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+
+
+-- * Shaped indices
+
+-- | An index into a shape-typed array.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (IxX (MapJust sh) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS <- IxS (matchZIX -> Just Refl)
+ where ZIS = IxS ZIX
+
+matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZIX _ = Nothing
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l))
+ where i :.$ IxS l = IxS (i :.% l)
+infixr 3 :.$
+
+data UnconsIxSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i)
+ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
+ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1)
+ , Refl <- lemMapJustCons @sh1 Refl =
+ Just (UnconsIxSRes i (IxS l))
+ixsUncons (IxS _) = Nothing
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxS sh = IxS sh Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxS sh i)
+#else
+instance Show i => Show (IxS sh i) where
+ showsPrec _ l = ixsShow shows l
+#endif
+
+ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS
+ixsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> IxS sh' i -> ShowS
+ go _ ZIS = id
+ go prefix (x :.$ xs) = showString prefix . f x . go "," xs
+
+ixsRank :: IxS sh i -> SNat (Rank sh)
+ixsRank ZIS = SNat
+ixsRank (_ :.$ sh) = snatSucc (ixsRank sh)
+
+{-# INLINE ixsFromList #-}
+ixsFromList :: ShS sh -> [i] -> IxS sh i
+ixsFromList sh l = assert (shsLength sh == length l)
+ $ IxS $ IsList.fromList l
+
+{-# INLINE ixsFromIxS #-}
+ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS sh l = assert (length sh == length l)
+ $ IxS $ IsList.fromList l
+
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+ixsHead :: IxS (n : sh) i -> i
+ixsHead (i :.$ _) = i
+
+ixsTail :: IxS (n : sh) i -> IxS sh i
+ixsTail (_ :.$ sh) = sh
+
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh
+ixsInit (_ :.$ ZIS) = ZIS
+
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh
+ixsLast (n :.$ ZIS) = n
+
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
+
+ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
+ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $
+ coerce (ixxAppend @_ @_ @i)
+
+ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
+ixsZip ZIS ZIS = ZIS
+ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
+
+{-# INLINE ixsZipWith #-}
+ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
+ixsZipWith _ ZIS ZIS = ZIS
+ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
+
+ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i
+ixsTakeLenPerm PNil _ = ZIS
+ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh
+ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i
+ixsDropLenPerm PNil sh = sh
+ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh
+ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i
+ixsPermute PNil _ = ZIS
+ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) =
+ case ixsIndex i sh of
+ item -> item :.$ ixsPermute is sh
+
+ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j
+ixsIndex SZ (n :.$ _) = n
+ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh
+ixsIndex _ ZIS = error "Index into empty shape"
+
+ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh)
+
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS = coerce
+
+{-# INLINEABLE ixsFromLinear #-}
+ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
+ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
+
+ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX = coerce
+
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' (ShS sh) = (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh
+
+-- * Shaped shapes
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+--
+-- Note that because the shape of a shape-typed array is known statically, you
+-- can also retrieve the array shape from a 'KnownShS' dictionary.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
+ deriving (NFData)
+
+instance Eq (ShS sh) where _ == _ = True
+instance Ord (ShS sh) where compare _ _ = EQ
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS <- ShS (matchZSX -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZSX _ = Nothing
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh))
+ where i :$$ ShS sh = ShS (SKnown i :$% sh)
+infixr 3 :$$
+
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh')) | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (ShS sh)
+#else
+instance Show (ShS sh) where
+ showsPrec d (ShS shx) = showsPrec d shx
+#endif
+
+instance TestEquality ShS where
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
+
+-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
+-- equal if and only if values are equal.)
+shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
+shsEqual = testEquality
+
+shsLength :: ShS sh -> Int
+shsLength (ShS shx) = shxLength shx
+
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) | Refl <- lemRankMapJust (Proxy @sh) =
+ shxRank shx
+
+lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust _ = unsafeCoerceRefl
+
+shsSize :: ShS sh -> Int
+shsSize (ShS sh) = shxSize sh
+
+-- | This is a partial @const@ that fails when the second argument
+-- doesn't match the first. We don't report the size of the list
+-- in case of errors in order not to retain the list.
+{-# INLINEABLE shsFromList #-}
+shsFromList :: ShS sh -> [Int] -> ShS sh
+shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0
+ where
+ go :: ShX sh' Int -> [Int] -> ()
+ go ZSX [] = ()
+ go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
+ | i == fromSNat' sn = go sh is
+ | otherwise = error "shsFromList: Value does not match typing"
+ go ConsUnknown{} _ = error "shsFromList: impossible case"
+ go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")"
+
+-- This is equivalent to but faster than @coerce shxToList@.
+{-# INLINEABLE shsToList #-}
+shsToList :: ShS sh -> [Int]
+shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ShX sh Int -> is
+ go ZSX = nil
+ go ConsUnknown{} = error "shsToList: impossible case"
+ go (ConsKnown snat rest) = fromSNat' snat `cons` go rest
+ in go l)
+
+shsHead :: ShS (n : sh) -> SNat n
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
+
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
+
+{-# INLINEABLE shsTakeIx #-}
+shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh
+shsTakeIx _ ZIS _ = ZSS
+shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh'
+
+{-# INLINEABLE shsDropIx #-}
+shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh'
+shsDropIx ZIS long = long
+shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long'
+
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @Int)
+
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
+
+shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @Int)
+
+shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLenPerm =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce (shxTakeLenPerm @Int)
+
+shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLenPerm =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce (shxDropLenPerm @Int)
+
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce (shxPermute @Int)
+
+shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex i (ShS sh) =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex @Int i sh of
+ SKnown SNat -> SNat
+
+shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
+
+type family Product sh where
+ Product '[] = 1
+ Product (n : ns) = n * Product ns
+
+shsProduct :: ShS sh -> SNat (Product sh)
+shsProduct ZSS = SNat
+shsProduct (n :$$ sh) = n `snatMul` shsProduct sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
+withKnownShS = withDict @(KnownShS sh)
+
+shsKnownShS :: ShS sh -> Dict KnownShS sh
+shsKnownShS ZSS = Dict
+shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict
+
+shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
+shsOrthotopeShape ZSS = Dict
+shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = ixsFromList (knownShS @sh)
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList = shsFromList (knownShS @sh)
+ toList = shsToList