aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-05-15 13:29:10 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-05-15 13:30:36 +0200
commitbd11ee13d58c512f1a9cc0ef06b36c722653ff6f (patch)
treea9354a9c1874bd4aea77a217db7981708707d60e /src/Data/Array
parent43ddff2e7f1e9f4d8855f573384e26b63d34f697 (diff)
The code compiles with only GHC nats
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested.hs4
-rw-r--r--src/Data/Array/Nested/Internal.hs61
2 files changed, 40 insertions, 25 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index c7d1819..c12d8ad 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -37,9 +37,6 @@ module Data.Array.Nested (
PrimElt,
Primitive(..),
- -- * Inductive natural numbers
- module Data.INat,
-
-- * Further utilities / re-exports
type (++),
Storable,
@@ -49,5 +46,4 @@ import Prelude hiding (mappend)
import Data.Array.Mixed
import Data.Array.Nested.Internal
-import Data.INat
import Foreign.Storable
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 54b567a..222247b 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -155,15 +155,26 @@ snatPred snp1 =
EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
GTI -> Nothing
+
+-- Stupid things that the type checker should be able to figure out in-line, but can't
+
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 Refl = Refl
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 Refl = Refl
+-- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+knownNatSucc :: KnownNat n => Dict KnownNat (1 + n)
+knownNatSucc = Dict
+
+
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
@@ -947,7 +958,7 @@ type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Show i => Show (ListR n i)
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
@@ -963,7 +974,7 @@ listRToList (i ::: is) = i : listRToList is
knownListR :: ListR n i -> Dict KnownNat n
knownListR ZR = Dict
-knownListR (_ ::: l) | Dict <- knownListR l = Dict
+knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m
-- | An index into a rank-typed array.
type role IxR nominal representational
@@ -1040,11 +1051,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
+ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: idx) = n :$? shCvtRX idx
+shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
@@ -1084,19 +1095,19 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownNat n, 1 <= n)
- => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a)
+ (Storable a, Num a, KnownNat n)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked
- . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing))
- . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a)
+ . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
+ . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
+ . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a)
$ arr
-rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n)
- => Ranked n a -> Ranked (n - 1) a
+rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)
+ => Ranked (1 + n) a -> Ranked n a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
@@ -1104,9 +1115,12 @@ rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
-rappend :: forall n a. (KnownNat n, Elt a, 1 <= n)
- => Ranked n a -> Ranked n a -> Ranked n a
-rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
+rappend :: forall n a. (KnownNat n, Elt a)
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
@@ -1125,16 +1139,19 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromList1 (coerce l))
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
rfromList :: Elt a => NonEmpty a -> Ranked 1 a
rfromList = Ranked . mfromList1 . fmap mscalar
-rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a]
-rtoList (Ranked arr) = coerce (mtoList1 arr)
+rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoList (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoList
@@ -1154,8 +1171,10 @@ rconstant sh x = coerce fromPrimitive (rconstantP sh x)
rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
rslice ivs = rlift $ \_ -> X.slice ivs
-rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a
-rrev1 = rlift $ \_ -> X.rev1
+rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 = rlift $ \(Proxy @sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
=> IShR n' -> Ranked n a -> Ranked n' a