aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-13 20:54:08 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-13 20:54:08 +0200
commit8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (patch)
tree87729f0ed4145645e77fa0b065f6efc83ebb229d /src/Data/Array/Nested
parentd5e02224b16b9d616e5193e0b9c48865b2415699 (diff)
Switch to GHC.TypeLits.Nat for shapes
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs37
1 files changed, 20 insertions, 17 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 840bb96..bdded69 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
@@ -17,12 +18,13 @@
{-|
TODO:
-* This module needs better structure with an Internal module and less public
- exports etc.
-
* We should be more consistent in whether functions take a 'StaticShapeX'
argument or a 'KnownShapeX' constraint.
+* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point
+ being that we need to do induction over the former, but the latter need to be
+ able to get large.
+
-}
module Data.Array.Nested.Internal where
@@ -35,8 +37,9 @@ import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
+import qualified GHC.TypeLits as GHC
-import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.Nat
@@ -56,10 +59,10 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
go SZ = SZX
go (SS n) = () :$? go n
-lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n
lemRankReplicate _ = go (knownNat @n)
where
- go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m
go SZ = Refl
go (SS n) | Refl <- go n = Refl
@@ -89,7 +92,7 @@ newtype Primitive a = Primitive a
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
+type Mixed :: [Maybe GHC.Nat] -> Type -> Type
data family Mixed sh a
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
@@ -113,7 +116,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-- | Internal helper data family mirrorring 'Mixed' that consists of mutable
-- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type
data family MixedVecs s sh a
newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
@@ -311,7 +314,7 @@ mgenerate sh f
where
checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
@@ -343,7 +346,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
+type Shaped :: [GHC.Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
@@ -427,18 +430,18 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where
-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
ShNil :: SShape '[]
- ShCons :: SNat n -> SShape sh -> SShape (n : sh)
+ ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
infixr 5 `ShCons`
-- | A statically-known shape of a shape-typed array.
class KnownShape sh where knownShape :: SShape sh
instance KnownShape '[] where knownShape = ShNil
-instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
+instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape
sshapeKnown :: SShape sh -> Dict KnownShape sh
sshapeKnown ShNil = Dict
-sshapeKnown (ShCons n sh) | Dict <- snatKnown n, Dict <- sshapeKnown sh = Dict
+sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
@@ -596,7 +599,7 @@ rtranspose perm (Ranked arr)
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
-type IxS :: [Nat] -> Type
+type IxS :: [GHC.Nat] -> Type
data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
@@ -604,7 +607,7 @@ infixr 5 ::$
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
-cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh
+cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh
ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
ixCvtXS ShNil IZX = IZS
@@ -640,13 +643,13 @@ slift f (Shaped arr)
= Shaped (mlift f arr)
ssumOuter1 :: forall sh n a.
- (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
. coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
- . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
+ . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr