{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked.Shape where

import Control.DeepSeq (NFData(..))
import Control.Exception (assert)
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (build)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types


-- * Ranked indices

-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (IxX (Replicate n Nothing) i)
  deriving (Eq, Ord, NFData, Functor, Foldable)

-- | The empty (rank-0) index. Matching on it brings @n ~ 0@ into scope.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR <- IxR (matchZIX @n -> Just Refl)
  where ZIR = IxR ZIX

-- | Matcher behind 'ZIR': recognises an empty underlying 'IxX' and turns that
-- into a proof that the rank is 0.
matchZIX :: forall n i. IxX (Replicate n Nothing) i -> Maybe (n :~: 0)
matchZIX ZIX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
matchZIX _ = Nothing

-- | Prepend a component to an index of one lower rank. As a pattern, splits
-- off the first component (via 'ixrUncons').
pattern (:.:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i
pattern i :.: l <- (ixrUncons -> Just (UnconsIxRRes i l))
  where i :.: IxR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = IxR (i :.% l)
infixr 3 :.:

-- | Result of 'ixrUncons': the first component and the remaining index, whose
-- rank is existentially bound to be one less than @n1@.
data UnconsIxRRes i n1 = forall n. (n + 1 ~ n1) => UnconsIxRRes i (IxR n i)

-- | Split off the first component of an index, or 'Nothing' if it is empty.
ixrUncons :: forall n1 i. IxR n1 i -> Maybe (UnconsIxRRes i n1)
ixrUncons (IxR ((:.%) @n @sh i l))
  | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl
  , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl
  , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl
  = Just (UnconsIxRRes i (IxR @(Rank sh) l))
ixrUncons (IxR _) = Nothing

{-# COMPLETE ZIR, (:.:) #-}

-- | For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (IxR n i)
#else
instance Show i => Show (IxR n i) where
  showsPrec _ = ixrShow shows
#endif

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
ixrEqRank :: IxR n i -> IxR n' i -> Maybe (n :~: n')
ixrEqRank ZIR ZIR = Just Refl
ixrEqRank (_ :.: sh) (_ :.: sh')
  | Just Refl <- ixrEqRank sh sh'
  = Just Refl
ixrEqRank _ _ = Nothing

-- | This compares the lists for value equality.
ixrEqual :: Eq i => IxR n i -> IxR n' i -> Maybe (n :~: n')
ixrEqual ZIR ZIR = Just Refl
ixrEqual (i :.: sh) (j :.: sh')
  | Just Refl <- ixrEqual sh sh'
  , i == j
  = Just Refl
ixrEqual _ _ = Nothing

-- | Render an index as @[c0,c1,...]@, showing each component with the given
-- function.
{-# INLINE ixrShow #-}
ixrShow :: forall n i. (i -> ShowS) -> IxR n i -> ShowS
ixrShow f l = showString "[" . go "" l . showString "]"
  where
    go :: String -> IxR n' i -> ShowS
    go _ ZIR = id
    go prefix (x :.: xs) = showString prefix . f x . go "," xs

-- | The rank of an index as a singleton, recomputed by walking the index.
ixrRank :: IxR n i -> SNat n
ixrRank ZIR = SNat
ixrRank (_ :.: sh) = snatSucc (ixrRank sh)

-- | The index of the given rank whose components are all 0.
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n

-- | Build an index of the given rank from a list.
--
-- The length of the list is always checked against the rank, and a mismatch
-- is reported with 'error'. This check must not rely on 'assert', which GHC
-- drops under @-O@; the @IsList@ instance promises a runtime length check.
-- As in 'shrFromList', the error message does not mention the list's length,
-- so that the list is not retained.
{-# INLINE ixrFromList #-}
ixrFromList :: SNat n -> [i] -> IxR n i
ixrFromList topsn topl = IxR $ go topsn topl
  where
    go :: SNat m -> [j] -> IxX (Replicate m Nothing) j
    go SZ [] = ZIX
    go SZ _ = error $ "ixrFromList: List too long (type says " ++ show (fromSNat' topsn) ++ ")"
    go (SS sn :: SNat m1) (x : xs)
      | Refl <- lemReplicateSucc2 (Proxy @m1) Refl
      = x :.% go sn xs
    go _ _ = error $ "ixrFromList: List too short (type says " ++ show (fromSNat' topsn) ++ ")"

-- | The first component of a non-empty index.
ixrHead :: IxR (n + 1) i -> i
ixrHead (i :.: _) = i

-- | All but the first component of a non-empty index.
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (_ :.: sh) = sh

-- | All but the last component of a non-empty index.
ixrInit :: IxR (n + 1) i -> IxR n i
ixrInit (n :.: sh@(_ :.: _)) = n :.: ixrInit sh
ixrInit (_ :.: ZIR) = ZIR

-- | The last component of a non-empty index.
ixrLast :: IxR (n + 1) i -> i
ixrLast (_ :.: sh@(_ :.: _)) = ixrLast sh
ixrLast (n :.: ZIR) = n

-- | Performs a runtime check that the lengths are identical.
ixrCast :: SNat n' -> IxR n i -> IxR n' i
ixrCast SZ ZIR = ZIR
ixrCast (SS n) (i :.: l) = i :.: ixrCast n l
ixrCast _ _ = error "ixrCast: ranks don't match"

-- | Concatenate two indices.
--
-- The needed type equality is asserted with 'unsafeCoerceRefl' rather than
-- proven: lemReplicatePlusApp requires SNat that would cause overhead (not
-- benchmarked).
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend =
  gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat)
                                 :~: Replicate n Nothing ++ Replicate m Nothing) $
    coerce (listxAppend @_ @_ @i)

-- | The component at 0-based position @k@. The bound is checked statically
-- through the @k + 1 <= n@ constraint.
ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i
ixrIndex SZ (x :.: _) = x
ixrIndex (SS i) (_ :.: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = ixrIndex i xs
ixrIndex _ ZIR = error "k + 1 <= 0"

-- | Pair up the components of two indices of the same rank.
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip ZIR ZIR = ZIR
ixrZip (i :.: irest) (j :.: jrest) = (i, j) :.: ixrZip irest jrest
ixrZip _ _ = error "ixrZip: impossible pattern needlessly required"

-- | Combine the components of two indices of the same rank pointwise.
{-# INLINE ixrZipWith #-}
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith _ ZIR ZIR = ZIR
ixrZipWith f (i :.: irest) (j :.: jrest) = f i j :.: ixrZipWith f irest jrest
ixrZipWith _ _ _ = error "ixrZipWith: impossible pattern needlessly required"

-- | Split an index into its first @m@ components and the remaining ones.
ixrSplitAt :: m <= n' => SNat m -> IxR n' i -> (IxR m i, IxR (n' - m) i)
ixrSplitAt SZ sh = (ZIR, sh)
ixrSplitAt (SS m) (n :.: sh) = (\(pre, post) -> (n :.: pre, post)) (ixrSplitAt m sh)
ixrSplitAt SS{} ZIR = error "m' + 1 <= 0"

-- | Permute the leading @length perm@ components of an index.
--
-- Component @j@ of the result is component @perm !! j@ of the original.
-- Components beyond the length of the permutation are kept in place. Calls
-- 'error' if the permutation is longer than the index, or if an entry points
-- outside the permuted prefix.
ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
ixrPermutePrefix = \perm sh ->
  TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
    case ixrRank sh of { shlen@SNat ->
      let sperm = ixrFromList permlen perm in
      case cmpNat permlen shlen of
        LTI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                         ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
    }
  where
    -- Select, for each permutation entry, that component of the prefix.
    applyPermRFull :: SNat m -> IxR k Int -> IxR m i -> IxR k i
    applyPermRFull _ ZIR _ = ZIR
    applyPermRFull sm@SNat (i :.: perm) l =
      TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
        case cmpNat (SNat @(idx + 1)) sm of
          LTI -> ixrIndex si l :.: applyPermRFull sm perm l
          EQI -> ixrIndex si l :.: applyPermRFull sm perm l
          GTI -> error "ixrPermutePrefix: Index in permutation out of range"

-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
{-# INLINEABLE ixrToLinear #-}
ixrToLinear :: Num i => IShR m -> IxR m i -> i
ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix)

-- | View a ranked index as the underlying mixed index (a zero-cost 'coerce').
ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
ixxFromIxR = coerce

-- | Convert a linear offset into a multidimensional index for the given
-- shape, by delegating to 'ixxFromLinear'.
{-# INLINEABLE ixrFromLinear #-}
ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i
ixrFromLinear (ShR sh) i
  | Refl <- lemRankReplicate (Proxy @m)
  = ixrFromIxX $ ixxFromLinear sh i

-- | View a mixed index whose dimensions are all unknown as a ranked index
-- (a zero-cost 'coerce').
ixrFromIxX :: IxX (Replicate n Nothing) i -> IxR n i
ixrFromIxX = coerce

-- | All indices of a shape, with 'Int' components (see 'shrEnum'').
shrEnum :: IShR n -> [IIxR n]
shrEnum = shrEnum'

-- | All indices of a shape, with components of any 'Num' type. Delegates to
-- 'shxEnum'' and coerces the resulting list without traversing it.
{-# INLINABLE shrEnum' #-}  -- ensure this can be specialised at use site
shrEnum' :: forall i n. Num i => IShR n -> [IxR n i]
shrEnum' (ShR sh)
  | Refl <- lemRankReplicate (Proxy @n)
  = (coerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh


-- * Ranked shapes

-- | The shape of a rank-typed array: a mixed shape in which every dimension
-- is unknown at the type level.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
  deriving (Eq, Ord, NFData, Functor)

-- | The empty (rank-0) shape. Matching on it brings @n ~ 0@ into scope.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR <- ShR (matchZSR @n -> Just Refl)
  where ZSR = ShR ZSX

-- | Matcher behind 'ZSR': recognises an empty underlying 'ShX' and turns that
-- into a proof that the rank is 0.
matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
matchZSR _ = Nothing

-- | Prepend a dimension to a shape of one lower rank. As a pattern, splits
-- off the first dimension (via 'shrUncons').
pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes i sh))
  where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh)
infixr 3 :$:

-- | Result of 'shrUncons': the first dimension and the remaining shape, whose
-- rank is existentially bound to be one less than @n1@.
data UnconsShRRes i n1 = forall n. (n + 1 ~ n1) => UnconsShRRes i (ShR n i)

-- | Split off the first dimension of a shape, or 'Nothing' if it is empty.
shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
  | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
  , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
  = Just (UnconsShRRes x (ShR sh'))
shrUncons (ShR _) = Nothing

{-# COMPLETE ZSR, (:$:) #-}

-- | A ranked shape with 'Int' dimensions.
type IShR n = ShR n Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ShR n i)
#else
instance Show i => Show (ShR n i) where
  showsPrec d (ShR l) = showsPrec d l
#endif

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqRank ZSR ZSR = Just Refl
shrEqRank (_ :$: sh) (_ :$: sh')
  | Just Refl <- shrEqRank sh sh'
  = Just Refl
shrEqRank _ _ = Nothing

-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqual ZSR ZSR = Just Refl
shrEqual (i :$: sh) (i' :$: sh')
  | Just Refl <- shrEqual sh sh'
  , i == i'
  = Just Refl
shrEqual _ _ = Nothing

-- | The length of the shape as a plain 'Int', by delegating to 'shxLength'.
-- The rank variable is named @n@, as in every other 'ShR' function.
shrLength :: ShR n i -> Int
shrLength (ShR l) = shxLength l

-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
shrRank :: forall n i. ShR n i -> SNat n
shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh

-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize (ShR sh) = shxSize sh

-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@.
-- We don't report the size of the list in case of errors in order not to retain the list.
{-# INLINEABLE shrFromList #-}
shrFromList :: SNat n -> [Int] -> IShR n
shrFromList snat topl = ShR $ go snat topl
  where
    go :: SNat n -> [Int] -> ShX (Replicate n Nothing) Int
    go SZ [] = ZSX
    go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")"
    go (SS sn :: SNat n1) (i : is)
      | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
      = ConsUnknown i (go sn is)
    go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")"

-- This is equivalent to but faster than @coerce shxToList@.
-- The list is produced with 'build', so it can take part in list fusion.
{-# INLINEABLE shrToList #-}
shrToList :: IShR n -> [Int]
shrToList (ShR l) = build (\(cons :: i -> is -> is) (nil :: is) ->
  let go :: ShX sh Int -> is
      go ZSX = nil
      go (ConsUnknown i rest) = i `cons` go rest
      go ConsKnown{} = error "shrToList: impossible case"
  in go l)

-- | The first dimension of a non-empty shape.
shrHead :: forall n i. ShR (n + 1) i -> i
shrHead (ShR sh)
  | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
  = case shxHead @Nothing @(Replicate n Nothing) sh of
      -- The head is statically 'Nothing', so 'SKnown' cannot occur here.
      SUnknown i -> i

-- | All but the first dimension of a non-empty shape.
shrTail :: forall n i. ShR (n + 1) i -> ShR n i
shrTail
  | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
  = coerce (shxTail @_ @_ @i)

-- | Take as many leading dimensions of the shape as the index has components.
{-# INLINEABLE shrTakeIx #-}
shrTakeIx :: forall n n' i j. Proxy n' -> IxR n j -> ShR (n + n') i -> ShR n i
shrTakeIx _ ZIR _ = ZSR
shrTakeIx p (_ :.: idx) sh = case sh of n :$: sh' -> n :$: shrTakeIx p idx sh'

-- | Drop as many leading dimensions of the shape as the index has components.
{-# INLINEABLE shrDropIx #-}
shrDropIx :: forall n n' i j. IxR n j -> ShR (n + n') i -> ShR n' i
shrDropIx ZIR long = long
shrDropIx (_ :.: short) long = case long of _ :$: long' -> shrDropIx short long'

-- | All but the last dimension of a non-empty shape.
shrInit :: forall n i. ShR (n + 1) i -> ShR n i
shrInit
  | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
  = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
    gcastWith (unsafeCoerceRefl
                 :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
      coerce (shxInit @i)

-- | The last dimension of a non-empty shape.
shrLast :: forall n i. ShR (n + 1) i -> i
shrLast (ShR sh)
  | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
  = case shxLast sh of
      SUnknown i -> i
      SKnown{} -> error "shrLast: impossible SKnown"

-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
shrCast SZ ZSR = ZSR
shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
shrCast _ _ = error "shrCast: ranks don't match"

-- | Concatenate two shapes. The needed type equality is asserted with
-- 'unsafeCoerceRefl' rather than proven.
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend =
  -- lemReplicatePlusApp requires an SNat
  gcastWith (unsafeCoerceRefl
               :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
    coerce (shxAppend @_ @i)

-- | Combine the dimensions of two shapes of the same rank pointwise.
{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith _ ZSR ZSR = ZSR
shrZipWith f (i :$: irest) (j :$: jrest) = f i j :$: shrZipWith f irest jrest
shrZipWith _ _ _ = error "shrZipWith: impossible pattern needlessly required"

-- | Split a shape into its first @m@ dimensions and the remaining ones.
shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
shrSplitAt SZ sh = (ZSR, sh)
shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
shrSplitAt SS{} ZSR = error "m' + 1 <= 0"

-- | The dimension at 0-based position @k@, by delegating to 'shxIndex'.
-- Unlike 'ixrIndex', there is no static @k + 1 <= n@ constraint. The rank
-- variable is named @n@, as in every other 'ShR' function.
shrIndex :: forall k n i. SNat k -> ShR n i -> i
shrIndex k (ShR sh) = case shxIndex @i k sh of
  SUnknown i -> i
  SKnown{} -> error "shrIndex: impossible SKnown"

-- Copy-pasted from ixrPermutePrefix, probably unavoidably.
-- | Permute the leading @length perm@ dimensions of a shape.
--
-- Dimension @j@ of the result is dimension @perm !! j@ of the original.
-- Dimensions beyond the length of the permutation are kept in place. Calls
-- 'error' if the permutation is longer than the shape, or if an entry points
-- outside the permuted prefix.
shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
shrPermutePrefix = \perm sh ->
  TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
    case shrRank sh of { shlen@SNat ->
      let sperm = shrFromList permlen perm in
      case cmpNat permlen shlen of
        LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                         ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
    }
  where
    -- Select, for each permutation entry, that dimension of the prefix.
    applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
    applyPermRFull _ ZSR _ = ZSR
    applyPermRFull sm@SNat (i :$: perm) l =
      TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
        case cmpNat (SNat @(idx + 1)) sm of
          LTI -> shrIndex si l :$: applyPermRFull sm perm l
          EQI -> shrIndex si l :$: applyPermRFull sm perm l
          GTI -> error "shrPermutePrefix: Index in permutation out of range"


-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList = ixrFromList (SNat @n)
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IShR n) where
  type Item (IShR n) = Int
  fromList = shrFromList (SNat @n)
  toList = shrToList