diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2024-05-15 13:29:10 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-05-15 13:30:36 +0200 | 
| commit | bd11ee13d58c512f1a9cc0ef06b36c722653ff6f (patch) | |
| tree | a9354a9c1874bd4aea77a217db7981708707d60e /src/Data/Array/Nested | |
| parent | 43ddff2e7f1e9f4d8855f573384e26b63d34f697 (diff) | |
The code compiles with only GHC nats
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 61 | 
1 files changed, 40 insertions, 21 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 54b567a..222247b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -155,15 +155,26 @@ snatPred snp1 =        EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)        GTI -> Nothing + +-- Stupid things that the type checker should be able to figure out in-line, but can't +  subst1 :: forall f a b. a :~: b -> f a :~: f b  subst1 Refl = Refl  subst2 :: forall f c a b. a :~: b -> f a c :~: f b c  subst2 Refl = Refl +-- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that.  lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a  lemReplicateSucc = unsafeCoerce Refl +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +knownNatSucc :: KnownNat n => Dict KnownNat (1 + n) +knownNatSucc = Dict + +  lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)  lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))    where @@ -947,7 +958,7 @@ type role ListR nominal representational  type ListR :: Nat -> Type -> Type  data ListR n i where    ZR :: ListR 0 i -  (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i +  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i  deriving instance Show i => Show (ListR n i)  deriving instance Eq i => Eq (ListR n i)  deriving instance Ord i => Ord (ListR n i) @@ -963,7 +974,7 @@ listRToList (i ::: is) = i : listRToList is  knownListR :: ListR n i -> Dict KnownNat n  knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict +knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m  -- | An index into a rank-typed array.  type role IxR nominal representational @@ -1040,11 +1051,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx  ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)  ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)  shCvtRX :: IShR n -> IShX (Replicate n Nothing)  shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)  shapeSizeR :: IShR n -> Int  shapeSizeR ZSR = 1 @@ -1084,19 +1095,19 @@ rlift f (Ranked arr)    = Ranked (mlift f arr)  rsumOuter1P :: forall n a. -               (Storable a, Num a, KnownNat n, 1 <= n) -            => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a) +               (Storable a, Num a, KnownNat n) +            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)  rsumOuter1P (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n) +  , Refl <- lemReplicateSucc @(Nothing @Nat) @n    = Ranked -    . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a)) -    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing)) -    . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a) +    . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) +    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) +    . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a)      $ arr -rsumOuter1 :: forall n a. -  (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n) -           => Ranked n a -> Ranked (n - 1) a +rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) +           => Ranked (1 + n) a -> Ranked n a  rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive  rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a @@ -1104,9 +1115,12 @@ rtranspose perm (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n)    = Ranked (mtranspose perm arr) -rappend :: forall n a. (KnownNat n, Elt a, 1 <= n) -        => Ranked n a -> Ranked n a -> Ranked n a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rappend :: forall n a. (KnownNat n, Elt a) +        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend +  | Dict <- lemKnownReplicate (Proxy @n) +  , Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))  rscalar :: Elt a => a -> Ranked 0 a  rscalar x = Ranked (mscalar x) @@ -1125,16 +1139,19 @@ rtoVectorP = coerce mtoVectorP  rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a  rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a +rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a  rfromList1 l    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mfromList1 (coerce l)) +  , Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))  rfromList :: Elt a => NonEmpty a -> Ranked 1 a  rfromList = Ranked . mfromList1 . fmap mscalar -rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) +rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoList (Ranked arr) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)  rtoList1 :: Elt a => Ranked 1 a -> [a]  rtoList1 = map runScalar . rtoList @@ -1154,8 +1171,10 @@ rconstant sh x = coerce fromPrimitive (rconstantP sh x)  rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a  rslice ivs = rlift $ \_ -> X.slice ivs -rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a -rrev1 = rlift $ \_ -> X.rev1 +rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 = rlift $ \(Proxy @sh') -> +  case lemReplicateSucc @(Nothing @Nat) @n of +    Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')  rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)           => IShR n' -> Ranked n a -> Ranked n' a | 
