aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Internal.hs10
2 files changed, 8 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 040b8d7..926c6ee 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -78,7 +78,7 @@ type family Rank sh where
Rank (_ : sh) = S (Rank sh)
type XArray :: [Maybe Nat] -> Type -> Type
-data XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
+newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
zeroIdx :: StaticShapeX sh -> IxX sh
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 15d72f0..eb4ef22 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -525,10 +525,14 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
vecs)
--- Utility function to satisfy the type checker sometimes
+-- Utility functions to satisfy the type checker sometimes
+
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed Refl x = x
+coerceMixedXArray :: Coercible (Mixed sh a) (XArray sh a) => XArray sh a -> Mixed sh a
+coerceMixedXArray = coerce
+
-- ====== API OF RANKED ARRAYS ====== --
@@ -583,7 +587,7 @@ rsumOuter1 :: forall n a.
rsumOuter1 (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
- . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a)
+ . coerceMixedXArray
. X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
. coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
$ arr
@@ -651,7 +655,7 @@ ssumOuter1 :: forall sh n a.
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
- . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
+ . coerceMixedXArray
. X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
$ arr