diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-13 20:54:08 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-13 20:54:08 +0200 |
commit | 8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (patch) | |
tree | 87729f0ed4145645e77fa0b065f6efc83ebb229d /src/Data/Array/Nested | |
parent | d5e02224b16b9d616e5193e0b9c48865b2415699 (diff) |
Switch to GHC.TypeLits.Nat for shapes
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 840bb96..bdded69 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} @@ -17,12 +18,13 @@ {-| TODO: -* This module needs better structure with an Internal module and less public - exports etc. - * We should be more consistent in whether functions take a 'StaticShapeX' argument or a 'KnownShapeX' constraint. +* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point + being that we need to do induction over the former, but the latter need to be + able to get large. + -} module Data.Array.Nested.Internal where @@ -35,8 +37,9 @@ import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM +import qualified GHC.TypeLits as GHC -import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X import Data.Nat @@ -56,10 +59,10 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) go SZ = SZX go (SS n) = () :$? go n -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n lemRankReplicate _ = go (knownNat @n) where - go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m go SZ = Refl go (SS n) | Refl <- go n = Refl @@ -89,7 +92,7 @@ newtype Primitive a = Primitive a -- -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type -- class. -type Mixed :: [Maybe Nat] -> Type -> Type +type Mixed :: [Maybe GHC.Nat] -> Type -> Type data family Mixed sh a newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) @@ -113,7 +116,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -- | Internal helper data family mirrorring 'Mixed' that consists of mutable -- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type data family MixedVecs s sh a newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) @@ -311,7 +314,7 @@ mgenerate sh f where checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' + checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -343,7 +346,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -- and 'Shaped' itself is again an instance of 'Elt' as well. -- -- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type +type Shaped :: [GHC.Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) @@ -427,18 +430,18 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where -- | The shape of a shape-typed array given as a list of 'SNat' values. data SShape sh where ShNil :: SShape '[] - ShCons :: SNat n -> SShape sh -> SShape (n : sh) + ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh) deriving instance Show (SShape sh) infixr 5 `ShCons` -- | A statically-known shape of a shape-typed array. class KnownShape sh where knownShape :: SShape sh instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape +instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape sshapeKnown :: SShape sh -> Dict KnownShape sh sshapeKnown ShNil = Dict -sshapeKnown (ShCons n sh) | Dict <- snatKnown n, Dict <- sshapeKnown sh = Dict +sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) @@ -596,7 +599,7 @@ rtranspose perm (Ranked arr) -- (traditionally called \"@Fin@\"). Note that because the shape of a -- shape-typed array is known statically, you can also retrieve the array shape -- from a 'KnownShape' dictionary. -type IxS :: [Nat] -> Type +type IxS :: [GHC.Nat] -> Type data IxS sh where IZS :: IxS '[] (::$) :: Int -> IxS sh -> IxS (n : sh) @@ -604,7 +607,7 @@ infixr 5 ::$ cvtSShapeIxS :: SShape sh -> IxS sh cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh +cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh ixCvtXS ShNil IZX = IZS @@ -640,13 +643,13 @@ slift f (Shaped arr) = Shaped (mlift f arr) ssumOuter1 :: forall sh n a. - (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) - . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) + . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) $ arr |