aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs363
1 files changed, 363 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
new file mode 100644
index 0000000..1c0b9eb
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -0,0 +1,363 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Array.Mixed.Types
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Mixed.Shape
+
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+
+instance NFData i => NFData (ListR n i) where
+ rnf ZR = ()
+ rnf (x ::: l) = rnf x `seq` rnf l
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: sh) (_ ::: sh')
+ | Just Refl <- listrEqRank sh sh'
+ = Just Refl
+listrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh')
+ | Just Refl <- listrEqual sh sh'
+ , i == j
+ = Just Refl
+listrEqual _ _ = Nothing
+
+listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR n' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrLength :: ListR n i -> Int
+listrLength = length
+
+listrRank :: ListR n i -> SNat n
+listrRank ZR = SNat
+listrRank (_ ::: sh) = snatSucc (listrRank sh)
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrHead :: ListR (n + 1) i -> i
+listrHead (i ::: _) = i
+listrHead ZR = error "unreachable"
+
+listrTail :: ListR (n + 1) i -> ListR n i
+listrTail (_ ::: sh) = sh
+listrTail ZR = error "unreachable"
+
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
+listrZip ZR ZR = ZR
+listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
+listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+
+listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
+listrZipWith _ ZR ZR = ZR
+listrZipWith f (i ::: irest) (j ::: jrest) =
+ f i j ::: listrZipWith f irest jrest
+listrZipWith _ _ _ =
+ error "listrZipWith: impossible pattern needlessly required"
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrRank sperm, listrRank sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+
+instance NFData i => NFData (IxR sh i)
+
+ixrLength :: IxR sh i -> Int
+ixrLength (IxR l) = listrLength l
+
+ixrRank :: IxR n i -> SNat n
+ixrRank (IxR sh) = listrRank sh
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
+ixCvtXR ZIX = ZIR
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
+ixCvtRX ZIR = ZIX
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
+
+ixrHead :: IxR (n + 1) i -> i
+ixrHead (IxR list) = listrHead list
+
+ixrTail :: IxR (n + 1) i -> IxR n i
+ixrTail (IxR list) = IxR (listrTail list)
+
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
+ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
+ixrAppend = coerce (listrAppend @_ @i)
+
+ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
+ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+
+ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
+ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: ShR sh = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+
+instance NFData i => NFData (ShR sh i)
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerceRefl
+
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerceRefl
+
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
+shrLength :: ShR sh i -> Int
+shrLength (ShR l) = listrLength l
+
+-- | This function can also be used to conjure up a 'KnownNat' dictionary;
+-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
+-- synonym yields 'KnownNat' evidence.
+shrRank :: ShR n i -> SNat n
+shrRank (ShR sh) = listrRank sh
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrHead :: ShR (n + 1) i -> i
+shrHead (ShR list) = listrHead list
+
+shrTail :: ShR (n + 1) i -> ShR n i
+shrTail (ShR list) = ShR (listrTail list)
+
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
+shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
+shrAppend = coerce (listrAppend @_ @i)
+
+shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
+shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+
+shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
+shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList topl = go (SNat @n) topl
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
+ ++ show (fromSNat (SNat @n)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList