aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Shape.hs467
1 files changed, 0 insertions, 467 deletions
diff --git a/src/Data/Array/Nested/Shape.hs b/src/Data/Array/Nested/Shape.hs
deleted file mode 100644
index 774b4bd..0000000
--- a/src/Data/Array/Nested/Shape.hs
+++ /dev/null
@@ -1,467 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RoleAnnotations #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Shape where
-
-import Data.Array.Mixed.Types
-import Data.Coerce (coerce)
-import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Kind (Type, Constraint)
-import Data.Monoid (Sum(..))
-import Data.Proxy
-import Data.Type.Equality
-import GHC.IsList (IsList)
-import GHC.IsList qualified as IsList
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
-listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR sh' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrToSNat :: ListR n i -> SNat n
-listrToSNat ZR = SNat
-listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat
-
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrToSNat sperm, listrToSNat sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
-infixr 3 :.:
-
-{-# COMPLETE ZIR, (:.:) #-}
-
-type IIxR n = IxR n Int
-
-instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
-
-ixrZero :: SNat n -> IIxR n
-ixrZero SZ = ZIR
-ixrZero (SS n) = 0 :.: ixrZero n
-
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
-
-ixrToSNat :: IxR n i -> SNat n
-ixrToSNat (IxR sh) = listrToSNat sh
-
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: (ShR sh) = ShR (i ::: sh)
-infixr 3 :$:
-
-{-# COMPLETE ZSR, (:$:) #-}
-
-type IShR n = ShR n Int
-
-instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerceRefl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerceRefl
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
-
--- | The number of elements in an array described by this shape.
-shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
-
-shrToSNat :: ShR n i -> SNat n
-shrToSNat (ShR sh) = listrToSNat sh
-
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList = go (SNat @n)
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error "IsList(ListR): Mismatched list length"
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (IxR n i) where
- type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
-
-
-type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
-infixr 3 ::$
-
-instance (forall n. Show (f n)) => Show (ListS sh f) where
- showsPrec _ = listsShow shows
-
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
-listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
-listsUncons ZS = Nothing
-
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFold _ ZS = mempty
-listsFold f (x ::$ xs) = f x <> listsFold f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
-listsShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListS sh' f -> ShowS
- go _ ZS = id
- go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-
-listsToList :: ListS sh (Const i) -> [i]
-listsToList ZS = []
-listsToList (Const i ::$ is) = i : listsToList is
-
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
-listsAppend ZS idx' = idx'
-listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
-listsTakeLenPerm PNil _ = ZS
-listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
-listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLenPerm PNil sh = sh
-listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
-listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
-listsPermute PNil _ = ZS
-listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
-
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
-
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
-
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
-
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
-
-applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
-applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-
-applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-applyPermIxS = coerce (applyPermS @(Const i))
-
-applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-applyPermShS = coerce (applyPermS @SNat)
-
-
--- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\"). Note that because the shape of a
--- shape-typed array is known statically, you can also retrieve the array shape
--- from a 'KnownShape' dictionary.
-type role IxS nominal representational
-type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord)
-
-pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
-
-pattern (:.$)
- :: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
-infixr 3 :.$
-
-{-# COMPLETE ZIS, (:.$) #-}
-
-type IIxS sh = IxS sh Int
-
-instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
-
-instance Functor (IxS sh) where
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- foldMap f (IxS l) = listsFold (f . getConst) l
-
-ixsZero :: ShS sh -> IIxS sh
-ixsZero ZSS = ZIS
-ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
-
-
--- | The shape of a shape-typed array given as a list of 'SNat' values.
-type role ShS nominal
-type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord)
-
-pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
-
-pattern (:$$)
- :: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
-
-infixr 3 :$$
-
-{-# COMPLETE ZSS, (:$$) #-}
-
-instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
-
-shsLength :: ShS sh -> Int
-shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)
-
-shsToList :: ShS sh -> [Int]
-shsToList ZSS = []
-shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-
-shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shCvtXS' (SUnknown _ :$% _) = error "impossible"
-
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
-
-shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
-
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShS :: [Nat] -> Constraint
-class KnownShS sh where knownShS :: ShS sh
-instance KnownShS '[] where knownShS = ZSS
-instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
-
-
--- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = listsToList
-
--- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShS sh => IsList (IxS sh i) where
- type Item (IxS sh i) = i
- fromList = IxS . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length and values are checked at runtime.
-instance KnownShS sh => IsList (ShS sh) where
- type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = shsToList