aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-26 10:27:39 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-26 10:27:39 +0200
commit8db33035826609bf48e15a82742981a58a0b5982 (patch)
tree848bf1bcfbd31a67c01b740c0065870f837543eb /src/Data/Array
parenta6f2809ed7e245d5eee4704b152783b4672cc212 (diff)
Refactor the clever replicate-aware Show instances
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed/XArray.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs54
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs16
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs18
4 files changed, 43 insertions, 54 deletions
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 91a11ed..93484dc 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -15,6 +15,9 @@
module Data.Array.Mixed.XArray where
import Control.DeepSeq (NFData)
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
+import Data.Array.Internal qualified as OI
import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
@@ -59,6 +62,12 @@ fromVector sh v
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr
+-- | This allows observing the strides in the underlying orthotope array. This
+-- can be useful for optimisation, but should be considered an implementation
+-- detail: strides may change in new versions of this library without notice.
+arrayStrides :: XArray sh a -> [Int]
+arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides
+
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 3bfda19..ca90889 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -24,9 +24,6 @@ import Control.DeepSeq (NFData(..))
import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.RankedS qualified as S
-import Data.Array.Internal.RankedS qualified as Internal.RankedS
-import Data.Array.Internal.RankedG qualified as Internal.RankedG
-import Data.Array.Internal qualified as Array.Internal
import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
@@ -52,6 +49,8 @@ import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Lemmas
+import Data.Bag
+
-- TODO:
-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
@@ -191,19 +190,24 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+showsMixedArray :: (Show a, Elt a)
+ => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@
+ -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@
+ -> Int -> Mixed sh a -> ShowS
+showsMixedArray fromlistPrefix replicatePrefix d arr =
+ showParen (d > 10) $
+ -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here
+ case mtoListLinear arr of
+ hd : _ : _
+ | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) ->
+ showString replicatePrefix . showString " " . showsPrec 11 hd
+ _ ->
+ showString fromlistPrefix . showString " " . shows (mtoListLinear arr)
+
instance (Show a, Elt a) => Show (Mixed sh a) where
- showsPrec d arr = showParen (d > 10) $
- let defaultResult =
- -- TODO: to avoid ambiguity, this should type-apply the shape to mfromListLinear
- showString "mfromListLinear " . shows (shxToList (mshape arr)) . showString " "
- . shows (mtoListLinear arr)
- in if stridesAreZero (shxLength $ mshape arr) (mstrideTree arr)
- then case mtoListLinear arr of
- [] -> defaultResult
- [_] -> defaultResult
- hd : _ -> showString "mreplicate " . shows (shxToList (mshape arr)) . showString " "
- . showsPrec 11 hd
- else defaultResult
+ showsPrec d arr =
+ let sh = show (shxToList (mshape arr))
+ in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr
instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
@@ -269,16 +273,6 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-data StrideTree =
- StrideLeaf [Int]
- | StrideNode StrideTree StrideTree
-
-stridesAreZero :: Int -> StrideTree -> Bool
-stridesAreZero prefixLen (StrideLeaf ss) =
- all (== 0) (take prefixLen ss)
-stridesAreZero prefixLen (StrideNode ss1 ss2) =
- stridesAreZero prefixLen ss1 && stridesAreZero prefixLen ss2
-
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
@@ -351,7 +345,9 @@ class Elt a where
mshowShapeTree :: Proxy a -> ShapeTree a -> String
- mstrideTree :: Mixed sh a -> StrideTree
+ -- | Returns the stride vector of each underlying component array making up
+ -- this mixed array.
+ marrayStrides :: Mixed sh a -> Bag [Int]
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
@@ -449,7 +445,7 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
- mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss
+ marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-- TODO: this use of toVector is suboptimal
@@ -523,7 +519,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b)
+ marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
@@ -656,7 +652,7 @@ instance Elt a => Elt (Mixed sh' a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mstrideTree (M_Nest _ arr) = mstrideTree arr
+ marrayStrides (M_Nest _ arr) = marrayStrides arr
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index cb8aae0..2aba1bc 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -65,17 +65,9 @@ deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
instance (Show a, Elt a) => Show (Ranked n a) where
- showsPrec d arr@(Ranked marr) = showParen (d > 10) $
- let defaultResult =
- showString "rfromListLinear " . shows (toList (rshape arr)) . showString " "
- . shows (rtoListLinear arr)
- in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr)
- then case rtoListLinear arr of
- [] -> defaultResult
- [_] -> defaultResult
- hd : _ -> showString "rreplicate " . shows (toList (rshape arr)) . showString " "
- . showsPrec 11 hd
- else defaultResult
+ showsPrec d arr@(Ranked marr) =
+ let sh = show (toList (rshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
instance Elt a => NFData (Ranked n a) where
rnf (Ranked arr) = rnf arr
@@ -152,7 +144,7 @@ instance Elt a => Elt (Ranked n a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mstrideTree (M_Ranked arr) = mstrideTree arr
+ marrayStrides (M_Ranked arr) = marrayStrides arr
mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs =
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d5c9612..b7cb14d 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -65,18 +65,10 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
-instance (Show a, Elt a) => Show (Shaped sh a) where
- showsPrec d arr@(Shaped marr) = showParen (d > 10) $
- let defaultResult =
- showString "sfromListLinear " . shows (shsToList (sshape arr)) . showString " "
- . shows (stoListLinear arr)
- in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr)
- then case stoListLinear arr of
- [] -> defaultResult
- [_] -> defaultResult
- hd : _ -> showString "sreplicate " . shows (shsToList (sshape arr)) . showString " "
- . showsPrec 11 hd
- else defaultResult
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
instance Elt a => NFData (Shaped sh a) where
rnf (Shaped arr) = rnf arr
@@ -150,7 +142,7 @@ instance Elt a => Elt (Shaped sh a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mstrideTree (M_Shaped arr) = mstrideTree arr
+ marrayStrides (M_Shaped arr) = marrayStrides arr
mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs =