aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ox-arrays.cabal4
-rw-r--r--src/Data/Array/Nested/Convert.hs4
-rw-r--r--src/Data/Array/Nested/Lemmas.hs4
-rw-r--r--src/Data/Array/Nested/Ranked.hs18
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs1
-rw-r--r--src/Data/Array/Nested/Types.hs4
6 files changed, 19 insertions, 16 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index be4bb03..142e4ed 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -115,8 +115,8 @@ library strided-array-ops
Data.Array.Strided.Arith.Internal.Lists.TH
build-depends:
base >=4.18 && <4.22,
- ghc-typelits-knownnat < 1,
- ghc-typelits-natnormalise < 1,
+ ghc-typelits-knownnat >= 0.8.0 && < 1,
+ ghc-typelits-natnormalise >= 0.8.0 && < 1,
template-haskell < 3,
vector < 0.14
hs-source-dirs: ops
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 2438f68..07777d5 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -136,7 +136,7 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX
ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
ixxFromIxR ZIR = ZIX
ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
(n :.% ixxFromIxR idx)
ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
@@ -146,7 +146,7 @@ ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
shxFromShR ZSR = ZSX
shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
(SUnknown n :$% shxFromShR idx)
shxFromShS :: ShS sh -> IShX (MapJust sh)
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index e8c3b9e..e089479 100644
--- a/src/Data/Array/Nested/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -43,6 +43,8 @@ lemAppLeft _ Refl = Refl
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+{- for now, the plugins can't derive a type for this code, see
+ https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214
lemReplicatePlusApp sn _ _ = go sn
where
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
@@ -51,6 +53,8 @@ lemReplicatePlusApp sn _ _ = go sn
| Refl <- lemReplicateSucc @a n
, Refl <- go n
= sym (lemReplicateSucc @a (SNat @(n'm1 + m)))
+-}
+lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 9778c54..8b95d0f 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -85,7 +85,7 @@ rsumOuter1P :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (msumOuter1P arr)
rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
@@ -108,7 +108,7 @@ rtranspose perm arr
rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
rconcat
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce mconcat
rappend :: forall n a. Elt a
@@ -116,7 +116,7 @@ rappend :: forall n a. Elt a
rappend arr1 arr2
| sn@SNat <- rrank arr1
, Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n)
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
arr1 arr2
@@ -142,7 +142,7 @@ rfromList1 l = Ranked (mfromList1 l)
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
@@ -161,7 +161,7 @@ rtoList = map runScalar . rtoListOuter
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoListLinear :: Elt a => Ranked n a -> [a]
@@ -173,9 +173,9 @@ rfromOrthotope sn arr
= let xarr = XArray arr
in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh)
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -255,7 +255,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= rlift (rrank arr)
(\_ -> X.sliceU i n)
arr
@@ -264,7 +264,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
rlift (rrank arr)
(\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
+ case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
arr
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 8b670e5..88a550c 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -138,7 +138,6 @@ listrCast = listrCastWithName "listrCast"
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
listrZip ZR ZR = ZR
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index ba22e97..a43ae0c 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -109,8 +109,8 @@ type family Replicate n a where
Replicate 0 a = '[]
Replicate n a = a : Replicate (n - 1) a
-lemReplicateSucc :: forall a n.
- SNat n -> (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc :: forall a n proxy.
+ proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc _ = unsafeCoerceRefl
type family MapJust l = r | r -> l where