diff options
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 61 |
1 files changed, 40 insertions, 21 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 54b567a..222247b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -155,15 +155,26 @@ snatPred snp1 = EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) GTI -> Nothing + +-- Stupid things that the type checker should be able to figure out in-line, but can't + subst1 :: forall f a b. a :~: b -> f a :~: f b subst1 Refl = Refl subst2 :: forall f c a b. a :~: b -> f a c :~: f b c subst2 Refl = Refl +-- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that. lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a lemReplicateSucc = unsafeCoerce Refl +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +knownNatSucc :: KnownNat n => Dict KnownNat (1 + n) +knownNatSucc = Dict + + lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where @@ -947,7 +958,7 @@ type role ListR nominal representational type ListR :: Nat -> Type -> Type data ListR n i where ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i deriving instance Show i => Show (ListR n i) deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) @@ -963,7 +974,7 @@ listRToList (i ::: is) = i : listRToList is knownListR :: ListR n i -> Dict KnownNat n knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict +knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m -- | An index into a rank-typed array. type role IxR nominal representational @@ -1040,11 +1051,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 @@ -1084,19 +1095,19 @@ rlift f (Ranked arr) = Ranked (mlift f arr) rsumOuter1P :: forall n a. - (Storable a, Num a, KnownNat n, 1 <= n) - => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a) + (Storable a, Num a, KnownNat n) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked - . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing)) - . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a) + . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) + . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) + . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a) $ arr -rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n) - => Ranked n a -> Ranked (n - 1) a +rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) + => Ranked (1 + n) a -> Ranked n a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a @@ -1104,9 +1115,12 @@ rtranspose perm (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mtranspose perm arr) -rappend :: forall n a. (KnownNat n, Elt a, 1 <= n) - => Ranked n a -> Ranked n a -> Ranked n a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rappend :: forall n a. (KnownNat n, Elt a) + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) @@ -1125,16 +1139,19 @@ rtoVectorP = coerce mtoVectorP rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a +rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromList1 (coerce l)) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) rfromList :: Elt a => NonEmpty a -> Ranked 1 a rfromList = Ranked . mfromList1 . fmap mscalar -rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) +rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoList (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] rtoList1 = map runScalar . rtoList @@ -1154,8 +1171,10 @@ rconstant sh x = coerce fromPrimitive (rconstantP sh x) rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a rslice ivs = rlift $ \_ -> X.slice ivs -rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a -rrev1 = rlift $ \_ -> X.rev1 +rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 = rlift $ \(Proxy @sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) => IShR n' -> Ranked n a -> Ranked n' a |