diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-26 10:27:39 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-26 10:27:39 +0200 |
commit | 8db33035826609bf48e15a82742981a58a0b5982 (patch) | |
tree | 848bf1bcfbd31a67c01b740c0065870f837543eb /src/Data | |
parent | a6f2809ed7e245d5eee4704b152783b4672cc212 (diff) |
Refactor the clever replicate-aware Show instances
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 54 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 16 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 18 | ||||
-rw-r--r-- | src/Data/Bag.hs | 18 |
5 files changed, 61 insertions, 54 deletions
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index 91a11ed..93484dc 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -15,6 +15,9 @@ module Data.Array.Mixed.XArray where import Control.DeepSeq (NFData) +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS +import Data.Array.Internal qualified as OI import Data.Array.Ranked qualified as ORB import Data.Array.RankedS qualified as S import Data.Coerce @@ -59,6 +62,12 @@ fromVector sh v toVector :: Storable a => XArray sh a -> VS.Vector a toVector (XArray arr) = S.toVector arr +-- | This allows observing the strides in the underlying orthotope array. This +-- can be useful for optimisation, but should be considered an implementation +-- detail: strides may change in new versions of this library without notice. +arrayStrides :: XArray sh a -> [Int] +arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides + scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 3bfda19..ca90889 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -24,9 +24,6 @@ import Control.DeepSeq (NFData(..)) import Control.Monad (forM_, when) import Control.Monad.ST import Data.Array.RankedS qualified as S -import Data.Array.Internal.RankedS qualified as Internal.RankedS -import Data.Array.Internal.RankedG qualified as Internal.RankedG -import Data.Array.Internal qualified as Array.Internal import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) @@ -52,6 +49,8 @@ import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +import Data.Bag + -- TODO: -- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -191,19 +190,24 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) +showsMixedArray :: (Show a, Elt a) + => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ + -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ + -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = + showParen (d > 10) $ + -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here + case mtoListLinear arr of + hd : _ : _ + | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> + showString replicatePrefix . showString " " . showsPrec 11 hd + _ -> + showString fromlistPrefix . showString " " . shows (mtoListLinear arr) + instance (Show a, Elt a) => Show (Mixed sh a) where - showsPrec d arr = showParen (d > 10) $ - let defaultResult = - -- TODO: to avoid ambiguity, this should type-apply the shape to mfromListLinear - showString "mfromListLinear " . shows (shxToList (mshape arr)) . showString " " - . shows (mtoListLinear arr) - in if stridesAreZero (shxLength $ mshape arr) (mstrideTree arr) - then case mtoListLinear arr of - [] -> defaultResult - [_] -> defaultResult - hd : _ -> showString "mreplicate " . shows (shxToList (mshape arr)) . showString " " - . showsPrec 11 hd - else defaultResult + showsPrec d arr = + let sh = show (shxToList (mshape arr)) + in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr instance Elt a => NFData (Mixed sh a) where rnf = mrnf @@ -269,16 +273,6 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -data StrideTree = - StrideLeaf [Int] - | StrideNode StrideTree StrideTree - -stridesAreZero :: Int -> StrideTree -> Bool -stridesAreZero prefixLen (StrideLeaf ss) = - all (== 0) (take prefixLen ss) -stridesAreZero prefixLen (StrideNode ss1 ss2) = - stridesAreZero prefixLen ss1 && stridesAreZero prefixLen ss2 - -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. @@ -351,7 +345,9 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String - mstrideTree :: Mixed sh a -> StrideTree + -- | Returns the stride vector of each underlying component array making up + -- this mixed array. + marrayStrides :: Mixed sh a -> Bag [Int] -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. @@ -449,7 +445,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss + marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal @@ -523,7 +519,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b) + marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -656,7 +652,7 @@ instance Elt a => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mstrideTree (M_Nest _ arr) = mstrideTree arr + marrayStrides (M_Nest _ arr) = marrayStrides arr mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index cb8aae0..2aba1bc 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -65,17 +65,9 @@ deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) instance (Show a, Elt a) => Show (Ranked n a) where - showsPrec d arr@(Ranked marr) = showParen (d > 10) $ - let defaultResult = - showString "rfromListLinear " . shows (toList (rshape arr)) . showString " " - . shows (rtoListLinear arr) - in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr) - then case rtoListLinear arr of - [] -> defaultResult - [_] -> defaultResult - hd : _ -> showString "rreplicate " . shows (toList (rshape arr)) . showString " " - . showsPrec 11 hd - else defaultResult + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr instance Elt a => NFData (Ranked n a) where rnf (Ranked arr) = rnf arr @@ -152,7 +144,7 @@ instance Elt a => Elt (Ranked n a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mstrideTree (M_Ranked arr) = mstrideTree arr + marrayStrides (M_Ranked arr) = marrayStrides arr mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs = diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d5c9612..b7cb14d 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -65,18 +65,10 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) -instance (Show a, Elt a) => Show (Shaped sh a) where - showsPrec d arr@(Shaped marr) = showParen (d > 10) $ - let defaultResult = - showString "sfromListLinear " . shows (shsToList (sshape arr)) . showString " " - . shows (stoListLinear arr) - in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr) - then case stoListLinear arr of - [] -> defaultResult - [_] -> defaultResult - hd : _ -> showString "sreplicate " . shows (shsToList (sshape arr)) . showString " " - . showsPrec 11 hd - else defaultResult +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr instance Elt a => NFData (Shaped sh a) where rnf (Shaped arr) = rnf arr @@ -150,7 +142,7 @@ instance Elt a => Elt (Shaped sh a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mstrideTree (M_Shaped arr) = mstrideTree arr + marrayStrides (M_Shaped arr) = marrayStrides arr mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs = diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs new file mode 100644 index 0000000..84c770a --- /dev/null +++ b/src/Data/Bag.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE DeriveTraversable #-} +module Data.Bag where + + +-- | An ordered sequence that can be folded over. +data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) | BList [Bag a] + deriving (Functor, Foldable, Traversable) + +-- Really only here for 'pure' +instance Applicative Bag where + pure = BOne + BZero <*> _ = BZero + BOne f <*> t = f <$> t + BTwo f1 f2 <*> t = BTwo (f1 <*> t) (f2 <*> t) + BList fs <*> t = BList [f <*> t | f <- fs] + +instance Semigroup (Bag a) where (<>) = BTwo +instance Monoid (Bag a) where mempty = BZero |