diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 80 |
1 files changed, 43 insertions, 37 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 1c0b9eb..8b670e5 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} @@ -27,7 +28,6 @@ module Data.Array.Nested.Ranked.Shape where import Control.DeepSeq (NFData(..)) -import Data.Array.Mixed.Types import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Type) @@ -39,10 +39,12 @@ import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Types +-- * Ranked lists + type role ListR nominal representational type ListR :: Nat -> Type -> Type data ListR n i where @@ -54,8 +56,12 @@ deriving instance Functor (ListR n) deriving instance Foldable (ListR n) infixr 3 ::: +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListR n i) +#else instance Show i => Show (ListR n i) where showsPrec _ = listrShow shows +#endif instance NFData i => NFData (ListR n i) where rnf ZR = () @@ -125,6 +131,10 @@ listrLast (_ ::: sh@(_ ::: _)) = listrLast sh listrLast (n ::: ZR) = n listrLast ZR = error "unreachable" +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" + listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -167,6 +177,8 @@ listrPermutePrefix = \perm sh -> GTI -> error "listrPermutePrefix: Index in permutation out of range" +-- * Ranked indices + -- | An index into a rank-typed array. type role IxR nominal representational type IxR :: Nat -> Type -> Type @@ -187,10 +199,16 @@ infixr 3 :.: {-# COMPLETE ZIR, (:.:) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxR n = IxR n Int +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxR n i) +#else instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l +#endif instance NFData i => NFData (IxR sh i) @@ -204,16 +222,6 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -226,6 +234,10 @@ ixrInit (IxR list) = IxR (listrInit list) ixrLast :: IxR (n + 1) i -> i ixrLast (IxR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) @@ -239,6 +251,8 @@ ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- * Ranked shapes + type role ShR nominal representational type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) @@ -260,35 +274,15 @@ infixr 3 :$: type IShR n = ShR n Int +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShR n i) +#else instance Show i => Show (ShR n i) where showsPrec _ (ShR l) = listrShow shows l +#endif instance NFData i => NFData (ShR sh i) -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') @@ -324,6 +318,10 @@ shrInit (ShR list) = ShR (listrInit list) shrLast :: ShR (n + 1) i -> i shrLast (ShR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i shrAppend = coerce (listrAppend @_ @i) @@ -361,3 +359,11 @@ instance KnownNat n => IsList (ShR n i) where type Item (ShR n i) = i fromList = ShR . IsList.fromList toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" |