aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ox-arrays.cabal1
-rw-r--r--src/Data/Array/Nested.hs4
-rw-r--r--src/Data/Array/Nested/Internal.hs61
-rw-r--r--src/Data/INat.hs121
-rw-r--r--test/Main.hs4
5 files changed, 42 insertions, 149 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index df6a805..a144629 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -10,7 +10,6 @@ library
Data.Array.Mixed
Data.Array.Nested
Data.Array.Nested.Internal
- Data.INat
build-depends:
base >=4.18 && <4.20,
ghc-typelits-knownnat,
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index c7d1819..c12d8ad 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -37,9 +37,6 @@ module Data.Array.Nested (
PrimElt,
Primitive(..),
- -- * Inductive natural numbers
- module Data.INat,
-
-- * Further utilities / re-exports
type (++),
Storable,
@@ -49,5 +46,4 @@ import Prelude hiding (mappend)
import Data.Array.Mixed
import Data.Array.Nested.Internal
-import Data.INat
import Foreign.Storable
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 54b567a..222247b 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -155,15 +155,26 @@ snatPred snp1 =
EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
GTI -> Nothing
+
+-- Stupid things that the type checker should be able to figure out in-line, but can't
+
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 Refl = Refl
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 Refl = Refl
+-- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+knownNatSucc :: KnownNat n => Dict KnownNat (1 + n)
+knownNatSucc = Dict
+
+
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
@@ -947,7 +958,7 @@ type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Show i => Show (ListR n i)
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
@@ -963,7 +974,7 @@ listRToList (i ::: is) = i : listRToList is
knownListR :: ListR n i -> Dict KnownNat n
knownListR ZR = Dict
-knownListR (_ ::: l) | Dict <- knownListR l = Dict
+knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m
-- | An index into a rank-typed array.
type role IxR nominal representational
@@ -1040,11 +1051,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
+ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: idx) = n :$? shCvtRX idx
+shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
@@ -1084,19 +1095,19 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownNat n, 1 <= n)
- => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a)
+ (Storable a, Num a, KnownNat n)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked
- . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing))
- . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a)
+ . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
+ . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
+ . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a)
$ arr
-rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n)
- => Ranked n a -> Ranked (n - 1) a
+rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)
+ => Ranked (1 + n) a -> Ranked n a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
@@ -1104,9 +1115,12 @@ rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
-rappend :: forall n a. (KnownNat n, Elt a, 1 <= n)
- => Ranked n a -> Ranked n a -> Ranked n a
-rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
+rappend :: forall n a. (KnownNat n, Elt a)
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
@@ -1125,16 +1139,19 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromList1 (coerce l))
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
rfromList :: Elt a => NonEmpty a -> Ranked 1 a
rfromList = Ranked . mfromList1 . fmap mscalar
-rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a]
-rtoList (Ranked arr) = coerce (mtoList1 arr)
+rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoList (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoList
@@ -1154,8 +1171,10 @@ rconstant sh x = coerce fromPrimitive (rconstantP sh x)
rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
rslice ivs = rlift $ \_ -> X.slice ivs
-rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a
-rrev1 = rlift $ \_ -> X.rev1
+rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 = rlift $ \(Proxy @sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
=> IShR n' -> Ranked n a -> Ranked n' a
diff --git a/src/Data/INat.hs b/src/Data/INat.hs
deleted file mode 100644
index af8f18b..0000000
--- a/src/Data/INat.hs
+++ /dev/null
@@ -1,121 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.INat where
-
-import Data.Proxy
-import Data.Type.Equality ((:~:) (Refl))
-import Numeric.Natural
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
--- | Evidence for the constraint @c a@.
-data Dict c a where
- Dict :: c a => Dict c a
-
--- | An inductive peano natural number. Intended to be used at the type level.
-data INat = Z | S INat
- deriving (Show)
-
--- | Singleton for a 'INat'.
-data SINat n where
- SZ :: SINat Z
- SS :: SINat n -> SINat (S n)
-deriving instance Show (SINat n)
-
--- | A singleton 'SINat' corresponding to @n@.
-class KnownINat n where inatSing :: SINat n
-instance KnownINat Z where inatSing = SZ
-instance KnownINat n => KnownINat (S n) where inatSing = SS inatSing
-
--- | Explicitly bidirectional pattern synonym that converts between a singleton
--- 'SINat' and evidence of a 'KnownINat' constraint. Analogous to 'GHC.SNat'.
-pattern SINat' :: () => KnownINat n => SINat n
-pattern SINat' <- (snatKnown -> Dict)
- where SINat' = inatSing
-
--- | A 'KnownINat' dictionary is just a singleton natural, so we can create
--- evidence of 'KnownINat' given an 'SINat'.
-snatKnown :: SINat n -> Dict KnownINat n
-snatKnown SZ = Dict
-snatKnown (SS n) | Dict <- snatKnown n = Dict
-
--- | Convert a 'INat' to a normal number.
-fromINat :: INat -> Natural
-fromINat Z = 0
-fromINat (S n) = 1 + fromINat n
-
--- | Convert an 'SINat' to a normal number.
-fromSINat :: SINat n -> Natural
-fromSINat SZ = 0
-fromSINat (SS n) = 1 + fromSINat n
-
--- | The value of a known inductive natural as a value-level integer.
-inatVal :: forall n. KnownINat n => Proxy n -> Natural
-inatVal _ = fromSINat (inatSing @n)
-
--- | Add two 'INat's
-type family n +! m where
- Z +! m = m
- S n +! m = S (n +! m)
-
--- | Convert a 'INat' to a "GHC.TypeLits" 'G.Nat'.
-type family FromINat n where
- FromINat Z = 0
- FromINat (S n) = 1 + FromINat n
-
--- | Convert a "GHC.TypeLits" 'G.Nat' to a 'INat'.
-type family ToINat (n :: Nat) where
- ToINat 0 = Z
- ToINat n = S (ToINat (n - 1))
-
-lemInjectiveFromINat :: n :~: ToINat (FromINat n)
-lemInjectiveFromINat = unsafeCoerce Refl
-
-lemSuccFromINat :: Proxy n -> 1 + FromINat n :~: FromINat (S n)
-lemSuccFromINat _ = unsafeCoerce Refl
-
-lemAddFromINat :: Proxy m -> Proxy n
- -> FromINat m + FromINat n :~: FromINat (m +! n)
-lemAddFromINat _ = unsafeCoerce Refl
-
-lemInjectiveToINat :: n :~: FromINat (ToINat n)
-lemInjectiveToINat = unsafeCoerce Refl
-
-lemSuccToINat :: Proxy n -> ToINat (1 + n) :~: S (ToINat n)
-lemSuccToINat _ = unsafeCoerce Refl
-
-lemAddToINat :: Proxy m -> Proxy n -> ToINat (m + n) :~: ToINat m +! ToINat n
-lemAddToINat _ _ = unsafeCoerce Refl
-
--- | If an inductive 'INat' is known, then the corresponding "GHC.TypeLits"
--- 'G.Nat' is also known.
-knownNatFromINat :: KnownINat n => Proxy n -> Dict KnownNat (FromINat n)
-knownNatFromINat (Proxy @n) = go (SINat' @n)
- where
- go :: SINat m -> Dict KnownNat (FromINat m)
- go SZ = Dict
- go (SS n) | Dict <- go n = Dict
-
--- * Some type-level inductive naturals
-
-type I0 = Z
-type I1 = S I0
-type I2 = S I1
-type I3 = S I2
-type I4 = S I3
-type I5 = S I4
-type I6 = S I5
-type I7 = S I6
-type I8 = S I7
-type I9 = S I8
diff --git a/test/Main.hs b/test/Main.hs
index 2363813..783d985 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -6,7 +6,7 @@ module Main where
import Data.Array.Nested
-arr :: Ranked I2 (Shaped [2, 3] (Double, Int))
+arr :: Ranked 2 (Shaped [2, 3] (Double, Int))
arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) ->
let s = 24*i + 6*j + 3*k + l
@@ -15,7 +15,7 @@ arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
foo :: (Double, Int)
foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS)
-bad :: Ranked I2 (Ranked I1 Double)
+bad :: Ranked 2 (Ranked 1 Double)
bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
rgenerate (i :$: ZSR) $ \(k :.: ZIR) ->
let s = 24*i + 6*j + 3*k