diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-13 20:54:08 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-13 20:54:08 +0200 |
commit | 8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (patch) | |
tree | 87729f0ed4145645e77fa0b065f6efc83ebb229d /src/Data/Array/Mixed.hs | |
parent | d5e02224b16b9d616e5193e0b9c48865b2415699 (diff) |
Switch to GHC.TypeLits.Nat for shapes
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 12c247f..2875203 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -2,6 +2,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -22,6 +23,13 @@ import Unsafe.Coerce (unsafeCoerce) import Data.Nat +-- | The 'GHC.SNat' pattern synonym is complete, but it doesn't have a +-- @COMPLETE@ pragma. This copy of it does. +pattern GHC_SNat :: () => GHC.KnownNat n => GHC.SNat n +pattern GHC_SNat = GHC.SNat +{-# COMPLETE GHC_SNat #-} + + -- | Type-level list append. type family l1 ++ l2 where '[] ++ l2 = l2 @@ -34,7 +42,7 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl -type IxX :: [Maybe Nat] -> Type +type IxX :: [Maybe GHC.Nat] -> Type data IxX sh where IZX :: IxX '[] (::@) :: Int -> IxX sh -> IxX (Just n : sh) @@ -44,23 +52,23 @@ infixr 5 ::@ infixr 5 ::? -- | The part of a shape that is statically known. -type StaticShapeX :: [Maybe Nat] -> Type +type StaticShapeX :: [Maybe GHC.Nat] -> Type data StaticShapeX sh where SZX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) + (:$@) :: GHC.SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) deriving instance Show (StaticShapeX sh) infixr 5 :$@ infixr 5 :$? -- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint +type KnownShapeX :: [Maybe GHC.Nat] -> Constraint class KnownShapeX sh where knownShapeX :: StaticShapeX sh instance KnownShapeX '[] where knownShapeX = SZX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = knownNat :$@ knownShapeX +instance (GHC.KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where + knownShapeX = GHC.natSing :$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :$? knownShapeX @@ -68,7 +76,7 @@ type family Rank sh where Rank '[] = Z Rank (_ : sh) = S (Rank sh) -type XArray :: [Maybe Nat] -> Type -> Type +type XArray :: [Maybe GHC.Nat] -> Type -> Type data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) deriving (Show) @@ -176,14 +184,13 @@ lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh lemKnownShapeX SZX = Dict -lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict +lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (n :$@ ssh) ssh' +lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' - , Dict <- snatKnown n = Dict lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' @@ -194,7 +201,7 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) where go :: StaticShapeX sh' -> [Int] -> IxX sh' go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l + go (n :$@ ssh) (_ : l) = fromIntegral (GHC.fromSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" |