summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs72
1 files changed, 37 insertions, 35 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 41fb1fd..15d72f0 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -21,7 +21,7 @@ TODO:
* We should be more consistent in whether functions take a 'StaticShapeX'
argument or a 'KnownShapeX' constraint.
-* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point
+* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
being that we need to do induction over the former, but the latter need to be
able to get large.
@@ -38,11 +38,11 @@ import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
-import qualified GHC.TypeLits as GHC
+import GHC.TypeLits
import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
-import Data.Nat
+import Data.INat
type family Replicate n a where
@@ -53,25 +53,25 @@ type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
+lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
where
- go :: SNat m -> StaticShapeX (Replicate m Nothing)
+ go :: SINat m -> StaticShapeX (Replicate m Nothing)
go SZ = SZX
go (SS n) = () :$? go n
-lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n
-lemRankReplicate _ = go (knownNat @n)
+lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (inatSing @n)
where
- go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m
+ go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
go (SS n) | Refl <- go n = Refl
-lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (knownNat @n)
+lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (inatSing @n)
where
- go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS n) | Refl <- go n = Refl
@@ -93,7 +93,7 @@ newtype Primitive a = Primitive a
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
-type Mixed :: [Maybe GHC.Nat] -> Type -> Type
+type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
@@ -117,7 +117,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-- | Internal helper data family mirrorring 'Mixed' that consists of mutable
-- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a
newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
@@ -315,7 +315,7 @@ mgenerate sh f
where
checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh'
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
@@ -325,29 +325,31 @@ mtranspose perm =
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
+-- represented on the type level as a 'INat'.
--
-- Valid elements of a ranked arrays are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
--- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a
+-- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
-- type-level natural that supports induction.
--
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
+type Ranked :: INat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
-- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's.
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [GHC.Nat] -> Type -> Type
+type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
@@ -364,7 +366,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (KnownNat n, Elt a) => Elt (Ranked n a) where
+instance (KnownINat n, Elt a) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
@@ -431,14 +433,14 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where
-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
ShNil :: SShape '[]
- ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh)
+ ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
infixr 5 `ShCons`
-- | A statically-known shape of a shape-typed array.
class KnownShape sh where knownShape :: SShape sh
instance KnownShape '[] where knownShape = ShNil
-instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape
+instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons natSing knownShape
sshapeKnown :: SShape sh -> Dict KnownShape sh
sshapeKnown ShNil = Dict
@@ -531,7 +533,7 @@ rewriteMixed Refl x = x
-- ====== API OF RANKED ARRAYS ====== --
-- | An index into a rank-typed array.
-type IxR :: Nat -> Type
+type IxR :: INat -> Type
data IxR n where
IZR :: IxR Z
(:::) :: Int -> IxR n -> IxR (S n)
@@ -547,7 +549,7 @@ ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
-rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n
+rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
@@ -556,19 +558,19 @@ rshape (Ranked arr)
rindex :: Elt a => Ranked n a -> IxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
(ixCvtRX idx))
-rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
-rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
+rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
@@ -576,7 +578,7 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1 :: forall n a.
- (Storable a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
+ (Storable a, Num a, KnownINat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
=> Ranked (S n) a -> Ranked n a
rsumOuter1 (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
@@ -586,7 +588,7 @@ rsumOuter1 (Ranked arr)
. coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
$ arr
-rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
@@ -600,7 +602,7 @@ rtranspose perm (Ranked arr)
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
-type IxS :: [GHC.Nat] -> Type
+type IxS :: [Nat] -> Type
data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
@@ -608,7 +610,7 @@ infixr 5 ::$
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
-cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh
+cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh
ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
ixCvtXS ShNil IZX = IZS
@@ -644,13 +646,13 @@ slift f (Shaped arr)
= Shaped (mlift f arr)
ssumOuter1 :: forall sh n a.
- (Storable a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ (Storable a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
. coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
- . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
+ . X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr