diff options
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 18 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 1 | ||||
-rw-r--r-- | src/Data/Array/Nested/Types.hs | 4 |
5 files changed, 17 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 2438f68..07777d5 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -136,7 +136,7 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i ixxFromIxR ZIR = ZIX ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) (n :.% ixxFromIxR idx) ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i @@ -146,7 +146,7 @@ ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i shxFromShR ZSR = ZSX shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) (SUnknown n :$% shxFromShR idx) shxFromShS :: ShS sh -> IShX (MapJust sh) diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e8c3b9e..e089479 100644 --- a/src/Data/Array/Nested/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -43,6 +43,8 @@ lemAppLeft _ Refl = Refl lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +{- for now, the plugins can't derive a type for this code, see + https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214 lemReplicatePlusApp sn _ _ = go sn where go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a @@ -51,6 +53,8 @@ lemReplicatePlusApp sn _ _ = go sn | Refl <- lemReplicateSucc @a n , Refl <- go n = sym (lemReplicateSucc @a (SNat @(n'm1 + m))) +-} +lemReplicatePlusApp _ _ _ = unsafeCoerceRefl lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 9778c54..8b95d0f 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -85,7 +85,7 @@ rsumOuter1P :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -108,7 +108,7 @@ rtranspose perm arr rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -116,7 +116,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -142,7 +142,7 @@ rfromList1 l = Ranked (mfromList1 l) rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a @@ -161,7 +161,7 @@ rtoList = map runScalar . rtoListOuter rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoListLinear :: Elt a => Ranked n a -> [a] @@ -173,9 +173,9 @@ rfromOrthotope sn arr = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -255,7 +255,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = rlift (rrank arr) (\_ -> X.sliceU i n) arr @@ -264,7 +264,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 arr = rlift (rrank arr) (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of + case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) arr diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 8b670e5..88a550c 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -138,7 +138,6 @@ listrCast = listrCastWithName "listrCast" listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" listrZip :: ListR n i -> ListR n j -> ListR n (i, j) listrZip ZR ZR = ZR diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index ba22e97..a43ae0c 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -109,8 +109,8 @@ type family Replicate n a where Replicate 0 a = '[] Replicate n a = a : Replicate (n - 1) a -lemReplicateSucc :: forall a n. - SNat n -> (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc :: forall a n proxy. + proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a lemReplicateSucc _ = unsafeCoerceRefl type family MapJust l = r | r -> l where |