summaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-13 20:54:08 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-13 20:54:08 +0200
commit8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (patch)
tree87729f0ed4145645e77fa0b065f6efc83ebb229d /src/Data/Array/Mixed.hs
parentd5e02224b16b9d616e5193e0b9c48865b2415699 (diff)
Switch to GHC.TypeLits.Nat for shapes
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs29
1 files changed, 18 insertions, 11 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 12c247f..2875203 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -22,6 +23,13 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.Nat
+-- | The 'GHC.SNat' pattern synonym is complete, but it doesn't have a
+-- @COMPLETE@ pragma. This copy of it does.
+pattern GHC_SNat :: () => GHC.KnownNat n => GHC.SNat n
+pattern GHC_SNat = GHC.SNat
+{-# COMPLETE GHC_SNat #-}
+
+
-- | Type-level list append.
type family l1 ++ l2 where
'[] ++ l2 = l2
@@ -34,7 +42,7 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
-type IxX :: [Maybe Nat] -> Type
+type IxX :: [Maybe GHC.Nat] -> Type
data IxX sh where
IZX :: IxX '[]
(::@) :: Int -> IxX sh -> IxX (Just n : sh)
@@ -44,23 +52,23 @@ infixr 5 ::@
infixr 5 ::?
-- | The part of a shape that is statically known.
-type StaticShapeX :: [Maybe Nat] -> Type
+type StaticShapeX :: [Maybe GHC.Nat] -> Type
data StaticShapeX sh where
SZX :: StaticShapeX '[]
- (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$@) :: GHC.SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
infixr 5 :$@
infixr 5 :$?
-- | Evidence for the static part of a shape.
-type KnownShapeX :: [Maybe Nat] -> Constraint
+type KnownShapeX :: [Maybe GHC.Nat] -> Constraint
class KnownShapeX sh where
knownShapeX :: StaticShapeX sh
instance KnownShapeX '[] where
knownShapeX = SZX
-instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = knownNat :$@ knownShapeX
+instance (GHC.KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = GHC.natSing :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :$? knownShapeX
@@ -68,7 +76,7 @@ type family Rank sh where
Rank '[] = Z
Rank (_ : sh) = S (Rank sh)
-type XArray :: [Maybe Nat] -> Type -> Type
+type XArray :: [Maybe GHC.Nat] -> Type -> Type
data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
deriving (Show)
@@ -176,14 +184,13 @@ lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX SZX = Dict
-lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
+lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (n :$@ ssh) ssh'
+lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
- , Dict <- snatKnown n
= Dict
lemAppKnownShapeX (() :$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
@@ -194,7 +201,7 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
+ go (n :$@ ssh) (_ : l) = fromIntegral (GHC.fromSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"