diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 10:31:46 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 10:33:54 +0200 |
commit | 018ebecade82009a3410f19982dd435b6e0715d8 (patch) | |
tree | 286688dcba6963211705c135a032ddd82df4cf88 /src/Data/Array/Nested/Internal.hs | |
parent | 3e74b0673caba7c04353c0cedb1d6e02de1fd007 (diff) |
Rename inductive naturals to INat
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 72 |
1 files changed, 37 insertions, 35 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 41fb1fd..15d72f0 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -21,7 +21,7 @@ TODO: * We should be more consistent in whether functions take a 'StaticShapeX' argument or a 'KnownShapeX' constraint. -* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point +* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point being that we need to do induction over the former, but the latter need to be able to get large. @@ -38,11 +38,11 @@ import Data.Type.Equality import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) -import qualified GHC.TypeLits as GHC +import GHC.TypeLits import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X -import Data.Nat +import Data.INat type family Replicate n a where @@ -53,25 +53,25 @@ type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) +lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) where - go :: SNat m -> StaticShapeX (Replicate m Nothing) + go :: SINat m -> StaticShapeX (Replicate m Nothing) go SZ = SZX go (SS n) = () :$? go n -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n -lemRankReplicate _ = go (knownNat @n) +lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (inatSing @n) where - go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m + go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m go SZ = Refl go (SS n) | Refl <- go n = Refl -lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (knownNat @n) +lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (inatSing @n) where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS n) | Refl <- go n = Refl @@ -93,7 +93,7 @@ newtype Primitive a = Primitive a -- -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type -- class. -type Mixed :: [Maybe GHC.Nat] -> Type -> Type +type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) @@ -117,7 +117,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -- | Internal helper data family mirrorring 'Mixed' that consists of mutable -- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type data family MixedVecs s sh a newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) @@ -315,7 +315,7 @@ mgenerate sh f where checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh' + checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -325,29 +325,31 @@ mtranspose perm = -- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. +-- represented on the type level as a 'INat'. -- -- Valid elements of a ranked arrays are described by the 'Elt' type class. -- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are -- supported (and are represented as a single, flattened, struct-of-arrays -- array internally). -- --- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a +-- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a -- type-level natural that supports induction. -- -- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type +type Ranked :: INat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. -- -- Like for 'Ranked', the valid elements are described by the 'Elt' type class, -- and 'Shaped' itself is again an instance of 'Elt' as well. -- -- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [GHC.Nat] -> Type -> Type +type Shaped :: [Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) @@ -364,7 +366,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. -instance (KnownNat n, Elt a) => Elt (Ranked n a) where +instance (KnownINat n, Elt a) => Elt (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) @@ -431,14 +433,14 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where -- | The shape of a shape-typed array given as a list of 'SNat' values. data SShape sh where ShNil :: SShape '[] - ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh) + ShCons :: SNat n -> SShape sh -> SShape (n : sh) deriving instance Show (SShape sh) infixr 5 `ShCons` -- | A statically-known shape of a shape-typed array. class KnownShape sh where knownShape :: SShape sh instance KnownShape '[] where knownShape = ShNil -instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape +instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons natSing knownShape sshapeKnown :: SShape sh -> Dict KnownShape sh sshapeKnown ShNil = Dict @@ -531,7 +533,7 @@ rewriteMixed Refl x = x -- ====== API OF RANKED ARRAYS ====== -- -- | An index into a rank-typed array. -type IxR :: Nat -> Type +type IxR :: INat -> Type data IxR n where IZR :: IxR Z (:::) :: Int -> IxR n -> IxR (S n) @@ -547,7 +549,7 @@ ixCvtRX IZR = IZX ixCvtRX (n ::: idx) = n ::? ixCvtRX idx -rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) @@ -556,19 +558,19 @@ rshape (Ranked arr) rindex :: Elt a => Ranked n a -> IxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a +rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) (ixCvtRX idx)) -rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a rgenerate sh f | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) -rlift :: forall n1 n2 a. (KnownNat n2, Elt a) +rlift :: forall n1 n2 a. (KnownINat n2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) @@ -576,7 +578,7 @@ rlift f (Ranked arr) = Ranked (mlift f arr) rsumOuter1 :: forall n a. - (Storable a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) + (Storable a, Num a, KnownINat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) => Ranked (S n) a -> Ranked n a rsumOuter1 (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) @@ -586,7 +588,7 @@ rsumOuter1 (Ranked arr) . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) $ arr -rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a +rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a rtranspose perm (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mtranspose perm arr) @@ -600,7 +602,7 @@ rtranspose perm (Ranked arr) -- (traditionally called \"@Fin@\"). Note that because the shape of a -- shape-typed array is known statically, you can also retrieve the array shape -- from a 'KnownShape' dictionary. -type IxS :: [GHC.Nat] -> Type +type IxS :: [Nat] -> Type data IxS sh where IZS :: IxS '[] (::$) :: Int -> IxS sh -> IxS (n : sh) @@ -608,7 +610,7 @@ infixr 5 ::$ cvtSShapeIxS :: SShape sh -> IxS sh cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh +cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh ixCvtXS ShNil IZX = IZS @@ -644,13 +646,13 @@ slift f (Shaped arr) = Shaped (mlift f arr) ssumOuter1 :: forall sh n a. - (Storable a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + (Storable a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) - . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh)) + . X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) $ arr |