summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs25
1 files changed, 12 insertions, 13 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index bf408e6..fadf1a7 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -608,9 +608,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= M_Ranked (mfromList (coerce l))
+ mtoList :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
mtoList (M_Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
- = coerce (mtoList arr)
+ = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList arr)
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -736,9 +737,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= M_Shaped (mfromList (coerce l))
+ mtoList :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoList (M_Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
- = coerce (mtoList arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList arr)
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -819,9 +821,6 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed Refl x = x
-coerceMixedXArray :: Coercible (Mixed sh a) (XArray sh a) => XArray sh a -> Mixed sh a
-coerceMixedXArray = coerce
-
-- ====== API OF RANKED ARRAYS ====== --
@@ -914,14 +913,14 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1 :: forall n a.
- (Storable a, Num a, KnownINat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
- => Ranked (S n) a -> Ranked n a
+ (Storable a, Num a, KnownINat n)
+ => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1 (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
- . coerceMixedXArray
+ . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
. X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
+ . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
$ arr
rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
@@ -1059,14 +1058,14 @@ slift f (Shaped arr)
= Shaped (mlift f arr)
ssumOuter1 :: forall sh n a.
- (Storable a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
- => Shaped (n : sh) a -> Shaped sh a
+ (Storable a, Num a, KnownNat n, KnownShape sh)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
- . coerceMixedXArray
+ . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
. X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
- . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
+ . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
$ arr
stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a