summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Mixed.hs29
-rw-r--r--src/Data/Array/Nested/Internal.hs37
-rw-r--r--test/Main.hs4
3 files changed, 40 insertions, 30 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 12c247f..2875203 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -22,6 +23,13 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.Nat
+-- | The 'GHC.SNat' pattern synonym is complete, but it doesn't have a
+-- @COMPLETE@ pragma. This copy of it does.
+pattern GHC_SNat :: () => GHC.KnownNat n => GHC.SNat n
+pattern GHC_SNat = GHC.SNat
+{-# COMPLETE GHC_SNat #-}
+
+
-- | Type-level list append.
type family l1 ++ l2 where
'[] ++ l2 = l2
@@ -34,7 +42,7 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
-type IxX :: [Maybe Nat] -> Type
+type IxX :: [Maybe GHC.Nat] -> Type
data IxX sh where
IZX :: IxX '[]
(::@) :: Int -> IxX sh -> IxX (Just n : sh)
@@ -44,23 +52,23 @@ infixr 5 ::@
infixr 5 ::?
-- | The part of a shape that is statically known.
-type StaticShapeX :: [Maybe Nat] -> Type
+type StaticShapeX :: [Maybe GHC.Nat] -> Type
data StaticShapeX sh where
SZX :: StaticShapeX '[]
- (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$@) :: GHC.SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
infixr 5 :$@
infixr 5 :$?
-- | Evidence for the static part of a shape.
-type KnownShapeX :: [Maybe Nat] -> Constraint
+type KnownShapeX :: [Maybe GHC.Nat] -> Constraint
class KnownShapeX sh where
knownShapeX :: StaticShapeX sh
instance KnownShapeX '[] where
knownShapeX = SZX
-instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = knownNat :$@ knownShapeX
+instance (GHC.KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = GHC.natSing :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :$? knownShapeX
@@ -68,7 +76,7 @@ type family Rank sh where
Rank '[] = Z
Rank (_ : sh) = S (Rank sh)
-type XArray :: [Maybe Nat] -> Type -> Type
+type XArray :: [Maybe GHC.Nat] -> Type -> Type
data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
deriving (Show)
@@ -176,14 +184,13 @@ lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX SZX = Dict
-lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
+lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (n :$@ ssh) ssh'
+lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
- , Dict <- snatKnown n
= Dict
lemAppKnownShapeX (() :$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
@@ -194,7 +201,7 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
+ go (n :$@ ssh) (_ : l) = fromIntegral (GHC.fromSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 840bb96..bdded69 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
@@ -17,12 +18,13 @@
{-|
TODO:
-* This module needs better structure with an Internal module and less public
- exports etc.
-
* We should be more consistent in whether functions take a 'StaticShapeX'
argument or a 'KnownShapeX' constraint.
+* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point
+ being that we need to do induction over the former, but the latter need to be
+ able to get large.
+
-}
module Data.Array.Nested.Internal where
@@ -35,8 +37,9 @@ import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
+import qualified GHC.TypeLits as GHC
-import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.Nat
@@ -56,10 +59,10 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
go SZ = SZX
go (SS n) = () :$? go n
-lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n
lemRankReplicate _ = go (knownNat @n)
where
- go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m
go SZ = Refl
go (SS n) | Refl <- go n = Refl
@@ -89,7 +92,7 @@ newtype Primitive a = Primitive a
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
+type Mixed :: [Maybe GHC.Nat] -> Type -> Type
data family Mixed sh a
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
@@ -113,7 +116,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-- | Internal helper data family mirrorring 'Mixed' that consists of mutable
-- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type
data family MixedVecs s sh a
newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
@@ -311,7 +314,7 @@ mgenerate sh f
where
checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
@@ -343,7 +346,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
+type Shaped :: [GHC.Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
@@ -427,18 +430,18 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where
-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
ShNil :: SShape '[]
- ShCons :: SNat n -> SShape sh -> SShape (n : sh)
+ ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
infixr 5 `ShCons`
-- | A statically-known shape of a shape-typed array.
class KnownShape sh where knownShape :: SShape sh
instance KnownShape '[] where knownShape = ShNil
-instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
+instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape
sshapeKnown :: SShape sh -> Dict KnownShape sh
sshapeKnown ShNil = Dict
-sshapeKnown (ShCons n sh) | Dict <- snatKnown n, Dict <- sshapeKnown sh = Dict
+sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
@@ -596,7 +599,7 @@ rtranspose perm (Ranked arr)
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
-type IxS :: [Nat] -> Type
+type IxS :: [GHC.Nat] -> Type
data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
@@ -604,7 +607,7 @@ infixr 5 ::$
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
-cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh
+cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh
ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
ixCvtXS ShNil IZX = IZS
@@ -640,13 +643,13 @@ slift f (Shaped arr)
= Shaped (mlift f arr)
ssumOuter1 :: forall sh n a.
- (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
. coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
- . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
+ . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr
diff --git a/test/Main.hs b/test/Main.hs
index 8257ff0..7e1a3a1 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -6,9 +6,9 @@ module Main where
import Data.Array.Nested
-arr :: Ranked N2 (Shaped [N2, N3] (Double, Int))
+arr :: Ranked N2 (Shaped [2, 3] (Double, Int))
arr = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) ->
- sgenerate @[N2, N3] (2 ::$ 3 ::$ IZS) $ \(k ::$ l ::$ IZS) ->
+ sgenerate @[2, 3] (2 ::$ 3 ::$ IZS) $ \(k ::$ l ::$ IZS) ->
let s = 24*i + 6*j + 3*k + l
in (fromIntegral s, s)