diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 54 |
1 files changed, 25 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 3bfda19..ca90889 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -24,9 +24,6 @@ import Control.DeepSeq (NFData(..)) import Control.Monad (forM_, when) import Control.Monad.ST import Data.Array.RankedS qualified as S -import Data.Array.Internal.RankedS qualified as Internal.RankedS -import Data.Array.Internal.RankedG qualified as Internal.RankedG -import Data.Array.Internal qualified as Array.Internal import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) @@ -52,6 +49,8 @@ import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +import Data.Bag + -- TODO: -- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -191,19 +190,24 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) +showsMixedArray :: (Show a, Elt a) + => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ + -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ + -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = + showParen (d > 10) $ + -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here + case mtoListLinear arr of + hd : _ : _ + | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> + showString replicatePrefix . showString " " . showsPrec 11 hd + _ -> + showString fromlistPrefix . showString " " . shows (mtoListLinear arr) + instance (Show a, Elt a) => Show (Mixed sh a) where - showsPrec d arr = showParen (d > 10) $ - let defaultResult = - -- TODO: to avoid ambiguity, this should type-apply the shape to mfromListLinear - showString "mfromListLinear " . shows (shxToList (mshape arr)) . showString " " - . shows (mtoListLinear arr) - in if stridesAreZero (shxLength $ mshape arr) (mstrideTree arr) - then case mtoListLinear arr of - [] -> defaultResult - [_] -> defaultResult - hd : _ -> showString "mreplicate " . shows (shxToList (mshape arr)) . showString " " - . showsPrec 11 hd - else defaultResult + showsPrec d arr = + let sh = show (shxToList (mshape arr)) + in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr instance Elt a => NFData (Mixed sh a) where rnf = mrnf @@ -269,16 +273,6 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -data StrideTree = - StrideLeaf [Int] - | StrideNode StrideTree StrideTree - -stridesAreZero :: Int -> StrideTree -> Bool -stridesAreZero prefixLen (StrideLeaf ss) = - all (== 0) (take prefixLen ss) -stridesAreZero prefixLen (StrideNode ss1 ss2) = - stridesAreZero prefixLen ss1 && stridesAreZero prefixLen ss2 - -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. @@ -351,7 +345,9 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String - mstrideTree :: Mixed sh a -> StrideTree + -- | Returns the stride vector of each underlying component array making up + -- this mixed array. + marrayStrides :: Mixed sh a -> Bag [Int] -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. @@ -449,7 +445,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss + marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal @@ -523,7 +519,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b) + marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -656,7 +652,7 @@ instance Elt a => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mstrideTree (M_Nest _ arr) = mstrideTree arr + marrayStrides (M_Nest _ arr) = marrayStrides arr mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs |