From 8a81f7ea9eed9afaec948910caaf0a5c498de6c6 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 13 Apr 2024 20:54:08 +0200 Subject: Switch to GHC.TypeLits.Nat for shapes --- src/Data/Array/Mixed.hs | 29 ++++++++++++++++++----------- src/Data/Array/Nested/Internal.hs | 37 ++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 28 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 12c247f..2875203 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -2,6 +2,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -22,6 +23,13 @@ import Unsafe.Coerce (unsafeCoerce) import Data.Nat +-- | The 'GHC.SNat' pattern synonym is complete, but it doesn't have a +-- @COMPLETE@ pragma. This copy of it does. +pattern GHC_SNat :: () => GHC.KnownNat n => GHC.SNat n +pattern GHC_SNat = GHC.SNat +{-# COMPLETE GHC_SNat #-} + + -- | Type-level list append. type family l1 ++ l2 where '[] ++ l2 = l2 @@ -34,7 +42,7 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl -type IxX :: [Maybe Nat] -> Type +type IxX :: [Maybe GHC.Nat] -> Type data IxX sh where IZX :: IxX '[] (::@) :: Int -> IxX sh -> IxX (Just n : sh) @@ -44,23 +52,23 @@ infixr 5 ::@ infixr 5 ::? -- | The part of a shape that is statically known. -type StaticShapeX :: [Maybe Nat] -> Type +type StaticShapeX :: [Maybe GHC.Nat] -> Type data StaticShapeX sh where SZX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) + (:$@) :: GHC.SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) deriving instance Show (StaticShapeX sh) infixr 5 :$@ infixr 5 :$? -- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint +type KnownShapeX :: [Maybe GHC.Nat] -> Constraint class KnownShapeX sh where knownShapeX :: StaticShapeX sh instance KnownShapeX '[] where knownShapeX = SZX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = knownNat :$@ knownShapeX +instance (GHC.KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where + knownShapeX = GHC.natSing :$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :$? knownShapeX @@ -68,7 +76,7 @@ type family Rank sh where Rank '[] = Z Rank (_ : sh) = S (Rank sh) -type XArray :: [Maybe Nat] -> Type -> Type +type XArray :: [Maybe GHC.Nat] -> Type -> Type data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) deriving (Show) @@ -176,14 +184,13 @@ lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh lemKnownShapeX SZX = Dict -lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict +lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (n :$@ ssh) ssh' +lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' - , Dict <- snatKnown n = Dict lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' @@ -194,7 +201,7 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) where go :: StaticShapeX sh' -> [Int] -> IxX sh' go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l + go (n :$@ ssh) (_ : l) = fromIntegral (GHC.fromSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 840bb96..bdded69 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} @@ -17,12 +18,13 @@ {-| TODO: -* This module needs better structure with an Internal module and less public - exports etc. - * We should be more consistent in whether functions take a 'StaticShapeX' argument or a 'KnownShapeX' constraint. +* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point + being that we need to do induction over the former, but the latter need to be + able to get large. + -} module Data.Array.Nested.Internal where @@ -35,8 +37,9 @@ import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM +import qualified GHC.TypeLits as GHC -import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X import Data.Nat @@ -56,10 +59,10 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) go SZ = SZX go (SS n) = () :$? go n -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n lemRankReplicate _ = go (knownNat @n) where - go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m go SZ = Refl go (SS n) | Refl <- go n = Refl @@ -89,7 +92,7 @@ newtype Primitive a = Primitive a -- -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type -- class. -type Mixed :: [Maybe Nat] -> Type -> Type +type Mixed :: [Maybe GHC.Nat] -> Type -> Type data family Mixed sh a newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) @@ -113,7 +116,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -- | Internal helper data family mirrorring 'Mixed' that consists of mutable -- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type data family MixedVecs s sh a newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) @@ -311,7 +314,7 @@ mgenerate sh f where checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' + checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -343,7 +346,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -- and 'Shaped' itself is again an instance of 'Elt' as well. -- -- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type +type Shaped :: [GHC.Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) @@ -427,18 +430,18 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where -- | The shape of a shape-typed array given as a list of 'SNat' values. data SShape sh where ShNil :: SShape '[] - ShCons :: SNat n -> SShape sh -> SShape (n : sh) + ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh) deriving instance Show (SShape sh) infixr 5 `ShCons` -- | A statically-known shape of a shape-typed array. class KnownShape sh where knownShape :: SShape sh instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape +instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape sshapeKnown :: SShape sh -> Dict KnownShape sh sshapeKnown ShNil = Dict -sshapeKnown (ShCons n sh) | Dict <- snatKnown n, Dict <- sshapeKnown sh = Dict +sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) @@ -596,7 +599,7 @@ rtranspose perm (Ranked arr) -- (traditionally called \"@Fin@\"). Note that because the shape of a -- shape-typed array is known statically, you can also retrieve the array shape -- from a 'KnownShape' dictionary. -type IxS :: [Nat] -> Type +type IxS :: [GHC.Nat] -> Type data IxS sh where IZS :: IxS '[] (::$) :: Int -> IxS sh -> IxS (n : sh) @@ -604,7 +607,7 @@ infixr 5 ::$ cvtSShapeIxS :: SShape sh -> IxS sh cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh +cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh ixCvtXS ShNil IZX = IZS @@ -640,13 +643,13 @@ slift f (Shaped arr) = Shaped (mlift f arr) ssumOuter1 :: forall sh n a. - (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) - . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) + . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) $ arr -- cgit v1.2.3-70-g09d2