summaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs62
1 files changed, 31 insertions, 31 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2bbf81d..049a0c4 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -17,16 +17,16 @@ import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
-import qualified GHC.TypeLits as GHC
+import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.Nat
+import Data.INat
--- | The 'GHC.SNat' pattern synonym is complete, but it doesn't have a
+-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
-pattern GHC_SNat :: () => GHC.KnownNat n => GHC.SNat n
-pattern GHC_SNat = GHC.SNat
+pattern GHC_SNat :: () => KnownNat n => SNat n
+pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}
@@ -42,7 +42,7 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
-type IxX :: [Maybe GHC.Nat] -> Type
+type IxX :: [Maybe Nat] -> Type
data IxX sh where
IZX :: IxX '[]
(::@) :: Int -> IxX sh -> IxX (Just n : sh)
@@ -52,23 +52,23 @@ infixr 5 ::@
infixr 5 ::?
-- | The part of a shape that is statically known.
-type StaticShapeX :: [Maybe GHC.Nat] -> Type
+type StaticShapeX :: [Maybe Nat] -> Type
data StaticShapeX sh where
SZX :: StaticShapeX '[]
- (:$@) :: GHC.SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
infixr 5 :$@
infixr 5 :$?
-- | Evidence for the static part of a shape.
-type KnownShapeX :: [Maybe GHC.Nat] -> Constraint
+type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
knownShapeX :: StaticShapeX sh
instance KnownShapeX '[] where
knownShapeX = SZX
-instance (GHC.KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = GHC.natSing :$@ knownShapeX
+instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = natSing :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :$? knownShapeX
@@ -76,8 +76,8 @@ type family Rank sh where
Rank '[] = Z
Rank (_ : sh) = S (Rank sh)
-type XArray :: [Maybe GHC.Nat] -> Type -> Type
-data XArray sh a = XArray (S.Array (GNat (Rank sh)) a)
+type XArray :: [Maybe Nat] -> Type -> Type
+data XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
zeroIdx :: StaticShapeX sh -> IxX sh
@@ -165,19 +165,19 @@ ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
- -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2)
+ -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
- -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1))
+ -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank :: IxX sh -> Dict KnownINat (Rank sh)
lemKnownNatRank IZX = Dict
lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh)
lemKnownNatRankSSX SZX = Dict
lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
@@ -201,14 +201,14 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (GHC.fromSNat n) ::@ go ssh l
+ go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
- , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.fromVector (shapeLshape sh) v)
toVector :: S.Unbox a => XArray sh a -> VS.Vector a
@@ -242,7 +242,7 @@ index xarr i
append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
append (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
- , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -252,14 +252,14 @@ rerank :: forall sh sh1 sh2 a b.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
| Dict <- lemKnownNatRankSSX ssh
- , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
, Dict <- lemKnownNatRankSSX ssh2
- , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh2))
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
(\a -> unXArray (f (XArray a)))
arr)
where
@@ -279,14 +279,14 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX ssh
- , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
, Dict <- lemKnownNatRankSSX ssh2
- , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh2))
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
@@ -296,7 +296,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
- , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
@@ -306,9 +306,9 @@ transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
, Refl <- lemRankApp ssh2 ssh1
, Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2)))
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)