{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked.Shape where

import Control.DeepSeq (NFData(..))
import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Nested.Mixed.Shape


-- | A list whose length @n@ is tracked in its type: 'ZR' has length 0 and
-- each ':::' adds one element.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  ZR :: ListR 0 i
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

instance Show i => Show (ListR n i) where
  showsPrec _ = listrShow shows

instance NFData i => NFData (ListR n i) where
  rnf ZR = ()
  rnf (x ::: l) = rnf x `seq` rnf l

-- | Result of 'listrUncons': the tail and the head of a non-empty list,
-- packaged with evidence that the list's length is one more than the tail's.
data UnconsListRRes i n1 =
  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split off the first element; 'Nothing' for the empty list.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
listrUncons ZR = Nothing

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqRank ZR ZR = Just Refl
listrEqRank (_ ::: sh) (_ ::: sh')
  | Just Refl <- listrEqRank sh sh'
  = Just Refl
listrEqRank _ _ = Nothing

-- | This compares the lists for value equality.
listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqual ZR ZR = Just Refl
listrEqual (i ::: sh) (j ::: sh')
  | Just Refl <- listrEqual sh sh'
  , i == j
  = Just Refl
listrEqual _ _ = Nothing

-- | Render the list as @[x1,x2,...]@, showing each element with the given
-- function.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
  where
    go :: String -> ListR n' i -> ShowS
    go _ ZR = id
    go prefix (x ::: xs) = showString prefix . f x . go "," xs

-- | The number of elements, as an 'Int'.
listrLength :: ListR n i -> Int
listrLength = length

-- | The length of the list as a type-level singleton.
listrRank :: ListR n i -> SNat n
listrRank ZR = SNat
listrRank (_ ::: sh) = snatSucc (listrRank sh)

listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh

-- | Convert an ordinary list to a 'ListR'. The resulting length is only
-- known inside the continuation.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList [] k = k ZR
listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)

-- In the next four functions the length index is @n + 1@, so the 'ZR'
-- equations are unreachable.

listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
listrHead ZR = error "unreachable"

listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"

-- | All elements except the last.
listrInit :: ListR (n + 1) i -> ListR n i
listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
listrInit (_ ::: ZR) = ZR
listrInit ZR = error "unreachable"

listrLast :: ListR (n + 1) i -> i
listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n
listrLast ZR = error "unreachable"

-- | The element at 0-based index @k@; the constraint @k + 1 <= n@ keeps the
-- index in bounds statically.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs)
  | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n)
  = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"

listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
listrZip ZR ZR = ZR
listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
listrZip _ _ = error "listrZip: impossible pattern needlessly required"

listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
listrZipWith _ ZR ZR = ZR
listrZipWith f (i ::: irest) (j ::: jrest) =
  f i j ::: listrZipWith f irest jrest
listrZipWith _ _ _ =
  error "listrZipWith: impossible pattern needlessly required"

-- | Reorder the first @length perm@ elements according to @perm@. Position
-- @j@ of the result holds the input element at index @perm !! j@; elements
-- after that prefix are left in place.
--
-- Every index is bounds-checked against the prefix length, including
-- negative indices, but @perm@ is not checked to be a permutation. It is an
-- error for @perm@ to be longer than the list.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrRank sperm, listrRank sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI carry different evidence, so they cannot be merged.
        LTI -> let (pre, post) = listrSplitAt permlen sh
               in listrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = listrSplitAt permlen sh
               in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- 'TN.withSomeSNat' takes a 'Natural'. Converting a negative 'Int'
      -- with 'fromIntegral' does not yield a meaningful index and typically
      -- raises an arithmetic underflow, so reject it with our own error.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"

-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

-- | The empty (rank-0) index.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

-- | Prepend a component to an index, increasing its rank by one.
pattern (:.:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

{-# COMPLETE ZIR, (:.:) #-}

-- | An index with 'Int' components.
type IIxR n = IxR n Int

instance Show i => Show (IxR n i) where
  showsPrec _ (IxR l) = listrShow shows l

-- No methods given: this uses deepseq's 'Generic'-based default 'rnf'.
instance NFData i => NFData (IxR n i)

-- | The number of components, as an 'Int'.
ixrLength :: IxR n i -> Int
ixrLength (IxR l) = listrLength l

ixrRank :: IxR n i -> SNat n
ixrRank (IxR sh) = listrRank sh

-- | The index of the given rank whose components are all 0.
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n

-- | Convert an 'IIxX' index to an 'IIxR' with the same components.
ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR ZIX = ZIR
ixCvtXR (n :.% idx) = n :.: ixCvtXR idx

-- | Convert an 'IIxR' index to an 'IIxX' indexed by @Replicate n Nothing@
-- (presumably: every dimension statically unknown).
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: (idx :: IxR m Int)) =
  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (n :.% ixCvtRX idx)

ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list

ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)

ixrInit :: IxR (n + 1) i -> IxR n i
ixrInit (IxR list) = IxR (listrInit list)

ixrLast :: IxR (n + 1) i -> i
ixrLast (IxR list) = listrLast list

ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)

ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2

ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2

-- | 'listrPermutePrefix' on the underlying list.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)

-- | A rank-typed shape; 'shrSize' treats each component as the size of one
-- dimension.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

-- | The empty (rank-0) shape.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

-- | Prepend a dimension to a shape, increasing its rank by one.
pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
  where i :$: ShR sh = ShR (i ::: sh)
infixr 3 :$:

{-# COMPLETE ZSR, (:$:) #-}

-- | A shape with 'Int' components.
type IShR n = ShR n Int

instance Show i => Show (ShR n i) where
  showsPrec _ (ShR l) = listrShow shows l

-- No methods given: this uses deepseq's 'Generic'-based default 'rnf'.
instance NFData i => NFData (ShR n i)

-- | Convert an 'IShX' indexed by @Replicate n Nothing@ to an 'IShR',
-- taking each dimension via 'fromSMayNat''.
--
-- NOTE: the type equalities used here (the @0 :~: n@ cast, @lem1@ and
-- @lem2@) are asserted with 'unsafeCoerceRefl', not checked by GHC.
shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
shCvtXR' ZSX =
  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
    ZSR
shCvtXR' (n :$% (idx :: IShX sh))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
    castWith (subst2 (lem1 @sh Refl))
      (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
  where
    lem1 :: forall sh' n' k.
            k : sh' :~: Replicate n' Nothing
         -> Rank sh' + 1 :~: n'
    lem1 Refl = unsafeCoerceRefl

    lem2 :: k : sh :~: Replicate n Nothing
         -> sh :~: Replicate (Rank sh) Nothing
    lem2 Refl = unsafeCoerceRefl

-- | Convert an 'IShR' to an 'IShX', wrapping every dimension in 'SUnknown'.
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (n :$: (idx :: ShR m Int)) =
  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (SUnknown n :$% shCvtRX idx)

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'

-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'

-- | The number of dimensions, as an 'Int'.
shrLength :: ShR n i -> Int
shrLength (ShR l) = listrLength l

-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank (ShR sh) = listrRank sh

-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh

shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list

shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)

shrInit :: ShR (n + 1) i -> ShR n i
shrInit (ShR list) = ShR (listrInit list)

shrLast :: ShR (n + 1) i -> i
shrLast (ShR list) = listrLast list

shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)

shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2

shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2

-- | 'listrPermutePrefix' on the underlying list.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i
  fromList topl = go (SNat @n) topl
    where
      go :: SNat n' -> [i] -> ListR n' i
      go SZ [] = ZR
      go (SS m) (x : xs) = x ::: go m xs
      go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
                         ++ show expected ++ ", list has length "
                         ++ actualLength ++ ")"

      -- The length demanded by the type.
      expected = fromSNat (SNat @n)

      -- Measure the input only up to the expected length. A too-short list
      -- is reported exactly. A too-long list, which may be infinite, is
      -- reported as "> n" instead of being fully traversed; calling
      -- 'length' on it could hang and never produce the error.
      actualLength = case drop (fromIntegral expected) topl of
        [] -> show (length topl)
        _ -> "> " ++ show expected
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList = IxR . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList = ShR . IsList.fromList
  toList = Foldable.toList