From 3defbdadab5080fc1f44895c06297d58ff3f5a43 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 12:28:05 +0200 Subject: Make XArray a newtype --- src/Data/Array/Mixed.hs | 2 +- src/Data/Array/Nested/Internal.hs | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 040b8d7..926c6ee 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -78,7 +78,7 @@ type family Rank sh where Rank (_ : sh) = S (Rank sh) type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) +newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) deriving (Show) zeroIdx :: StaticShapeX sh -> IxX sh diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 15d72f0..eb4ef22 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -525,10 +525,14 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where vecs) --- Utility function to satisfy the type checker sometimes +-- Utility functions to satisfy the type checker sometimes + rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a rewriteMixed Refl x = x +coerceMixedXArray :: Coercible (Mixed sh a) (XArray sh a) => XArray sh a -> Mixed sh a +coerceMixedXArray = coerce + -- ====== API OF RANKED ARRAYS ====== -- @@ -583,7 +587,7 @@ rsumOuter1 :: forall n a. rsumOuter1 (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked - . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) + . coerceMixedXArray . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) $ arr @@ -651,7 +655,7 @@ ssumOuter1 :: forall sh n a. ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) + . coerceMixedXArray . X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) $ arr -- cgit v1.2.3-70-g09d2