aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs467
1 files changed, 467 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
new file mode 100644
index 0000000..319f171
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -0,0 +1,467 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Internal.Shape where
+
+import Data.Array.Mixed.Types
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Kind (Type, Constraint)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR sh' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrToSNat :: ListR n i -> SNat n
+listrToSNat ZR = SNat
+listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrToSNat sperm, listrToSNat sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
+ixCvtXR ZIX = ZIR
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
+ixCvtRX ZIR = ZIX
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
+
+ixrToSNat :: IxR n i -> SNat n
+ixrToSNat (IxR sh) = listrToSNat sh
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: (ShR sh) = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerceRefl
+
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerceRefl
+
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrToSNat :: ShR n i -> SNat n
+shrToSNat (ShR sh) = listrToSNat sh
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList = go (SNat @n)
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error "IsList(ListR): Mismatched list length"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList
+
+
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ ZS :: ListS '[] f
+ -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
+ (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+instance (forall n. Show (f n)) => Show (ListS sh f) where
+ showsPrec _ = listsShow shows
+
+data UnconsListSRes f sh1 =
+ forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
+listsUncons ZS = Nothing
+
+listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+listsFmap _ ZS = ZS
+listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
+
+listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFold _ ZS = mempty
+listsFold f (x ::$ xs) = f x <> listsFold f xs
+
+listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListS sh' f -> ShowS
+ go _ ZS = id
+ go prefix (x ::$ xs) = showString prefix . f x . go "," xs
+
+listsToList :: ListS sh (Const i) -> [i]
+listsToList ZS = []
+listsToList (Const i ::$ is) = i : listsToList is
+
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm PNil _ = ZS
+listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
+listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm PNil sh = sh
+listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
+listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
+ case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ (item, SNat) -> item ::$ listsPermute is sh
+
+-- TODO: remove this SNat when the KnownNat constaint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
+listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh
+listsIndex _ _ _ ZS = error "Index into empty shape"
+
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen = coerce (listsTakeLenPerm @SNat)
+
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute = coerce (listsPermute @SNat)
+
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+
+applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+
+applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS = coerce (applyPermS @(Const i))
+
+applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS = coerce (applyPermS @SNat)
+
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Eq, Ord)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+type IIxS sh = IxS sh Int
+
+instance Show i => Show (IxS sh i) where
+ showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = listsFold (f . getConst) l
+
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
+ixCvtXS ZSS ZIX = ZIS
+ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
+
+ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
+ixCvtSX ZIS = ZIX
+ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Eq, Ord)
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+instance Show (ShS sh) where
+ showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+
+shsLength :: ShS sh -> Int
+shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)
+
+shsToList :: ShS sh -> [Int]
+shsToList ZSS = []
+shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+
+shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
+shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shCvtXS' (SUnknown _ :$% _) = error "impossible"
+
+shCvtSX :: ShS sh -> IShX (MapJust sh)
+shCvtSX ZSS = ZSX
+shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
+
+shsSize :: ShS sh -> Int
+shsSize ZSS = 1
+shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownShS sh => IsList (ListS sh (Const i)) where
+ type Item (ListS sh (Const i)) = i
+ fromList topl = go (knownShS @sh) topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listsToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = IxS . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList topl = ShS (go (knownShS @sh) topl)
+ where
+ go :: ShS sh' -> [Int] -> ListS sh' SNat
+ go ZSS [] = ZS
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = sn ::$ go sh is
+ | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shsToList