diff options
| -rw-r--r-- | ox-arrays.cabal | 1 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 54 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 16 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 18 | ||||
| -rw-r--r-- | src/Data/Bag.hs | 18 | 
6 files changed, 62 insertions, 54 deletions
| diff --git a/ox-arrays.cabal b/ox-arrays.cabal index f00e6cb..ea7d75f 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -49,6 +49,7 @@ library      Data.Array.Nested.Internal.Ranked      Data.Array.Nested.Internal.Shape      Data.Array.Nested.Internal.Shaped +    Data.Bag    if flag(trace-wrappers)      exposed-modules: diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index 91a11ed..93484dc 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -15,6 +15,9 @@  module Data.Array.Mixed.XArray where  import Control.DeepSeq (NFData) +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS +import Data.Array.Internal qualified as OI  import Data.Array.Ranked qualified as ORB  import Data.Array.RankedS qualified as S  import Data.Coerce @@ -59,6 +62,12 @@ fromVector sh v  toVector :: Storable a => XArray sh a -> VS.Vector a  toVector (XArray arr) = S.toVector arr +-- | This allows observing the strides in the underlying orthotope array. This +-- can be useful for optimisation, but should be considered an implementation +-- detail: strides may change in new versions of this library without notice. +arrayStrides :: XArray sh a -> [Int] +arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides +  scalar :: Storable a => a -> XArray '[] a  scalar = XArray . S.scalar diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 3bfda19..ca90889 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -24,9 +24,6 @@ import Control.DeepSeq (NFData(..))  import Control.Monad (forM_, when)  import Control.Monad.ST  import Data.Array.RankedS qualified as S -import Data.Array.Internal.RankedS qualified as Internal.RankedS -import Data.Array.Internal.RankedG qualified as Internal.RankedG -import Data.Array.Internal qualified as Array.Internal  import Data.Bifunctor (bimap)  import Data.Coerce  import Data.Foldable (toList) @@ -52,6 +49,8 @@ import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Lemmas +import Data.Bag +  -- TODO:  --   sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -191,19 +190,24 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s  data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) +showsMixedArray :: (Show a, Elt a) +                => String  -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ +                -> String  -- ^ replicate prefix: e.g. @rreplicate [2,3]@ +                -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = +  showParen (d > 10) $ +    -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here +    case mtoListLinear arr of +      hd : _ : _ +        | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> +            showString replicatePrefix . showString " " . showsPrec 11 hd +      _ -> +       showString fromlistPrefix . showString " " . shows (mtoListLinear arr) +  instance (Show a, Elt a) => Show (Mixed sh a) where -  showsPrec d arr = showParen (d > 10) $ -    let defaultResult = -          -- TODO: to avoid ambiguity, this should type-apply the shape to mfromListLinear -          showString "mfromListLinear " . shows (shxToList (mshape arr)) . showString " " -            . shows (mtoListLinear arr) -    in if stridesAreZero (shxLength $ mshape arr) (mstrideTree arr) -       then case mtoListLinear arr of -         [] -> defaultResult -         [_] -> defaultResult -         hd : _ -> showString "mreplicate " . shows (shxToList (mshape arr)) . showString " " -                     . showsPrec 11 hd -       else defaultResult +  showsPrec d arr = +    let sh = show (shxToList (mshape arr)) +    in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr  instance Elt a => NFData (Mixed sh a) where    rnf = mrnf @@ -269,16 +273,6 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)  matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a  matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -data StrideTree = -    StrideLeaf [Int] -  | StrideNode StrideTree StrideTree - -stridesAreZero :: Int -> StrideTree -> Bool -stridesAreZero prefixLen (StrideLeaf ss) = -  all (== 0) (take prefixLen ss) -stridesAreZero prefixLen (StrideNode ss1 ss2) = -  stridesAreZero prefixLen ss1 && stridesAreZero prefixLen ss2 -  -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or  -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'  -- a@; see the documentation for 'Primitive' for more details. @@ -351,7 +345,9 @@ class Elt a where    mshowShapeTree :: Proxy a -> ShapeTree a -> String -  mstrideTree :: Mixed sh a -> StrideTree +  -- | Returns the stride vector of each underlying component array making up +  -- this mixed array. +  marrayStrides :: Mixed sh a -> Bag [Int]    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors. @@ -449,7 +445,7 @@ instance Storable a => Elt (Primitive a) where    mshapeTreeEq _ () () = True    mshapeTreeEmpty _ () = False    mshowShapeTree _ () = "()" -  mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss +  marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)    mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x    -- TODO: this use of toVector is suboptimal @@ -523,7 +519,7 @@ instance (Elt a, Elt b) => Elt (a, b) where    mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'    mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2    mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" -  mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b) +  marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b    mvecsWrite sh i (x, y) (MV_Tup2 a b) = do      mvecsWrite sh i x a      mvecsWrite sh i y b @@ -656,7 +652,7 @@ instance Elt a => Elt (Mixed sh' a) where    mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" -  mstrideTree (M_Nest _ arr) = mstrideTree arr +  marrayStrides (M_Nest _ arr) = marrayStrides arr    mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index cb8aae0..2aba1bc 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -65,17 +65,9 @@ deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)  deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)  instance (Show a, Elt a) => Show (Ranked n a) where -  showsPrec d arr@(Ranked marr) = showParen (d > 10) $ -    let defaultResult = -          showString "rfromListLinear " . shows (toList (rshape arr)) . showString " " -            . shows (rtoListLinear arr) -    in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr) -       then case rtoListLinear arr of -         [] -> defaultResult -         [_] -> defaultResult -         hd : _ -> showString "rreplicate " . shows (toList (rshape arr)) . showString " " -                     . showsPrec 11 hd -       else defaultResult +  showsPrec d arr@(Ranked marr) = +    let sh = show (toList (rshape arr)) +    in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr  instance Elt a => NFData (Ranked n a) where    rnf (Ranked arr) = rnf arr @@ -152,7 +144,7 @@ instance Elt a => Elt (Ranked n a) where    mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" -  mstrideTree (M_Ranked arr) = mstrideTree arr +  marrayStrides (M_Ranked arr) = marrayStrides arr    mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()    mvecsWrite sh idx (Ranked arr) vecs = diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d5c9612..b7cb14d 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -65,18 +65,10 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)  deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)  deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) -instance (Show a, Elt a) => Show (Shaped sh a) where -  showsPrec d arr@(Shaped marr) = showParen (d > 10) $ -    let defaultResult = -          showString "sfromListLinear " . shows (shsToList (sshape arr)) . showString " " -            . shows (stoListLinear arr) -    in if stridesAreZero (shxLength $ mshape marr) (mstrideTree marr) -       then case stoListLinear arr of -         [] -> defaultResult -         [_] -> defaultResult -         hd : _ -> showString "sreplicate " . shows (shsToList (sshape arr)) . showString " " -                     . showsPrec 11 hd -       else defaultResult +instance (Show a, Elt a) => Show (Shaped n a) where +  showsPrec d arr@(Shaped marr) = +    let sh = show (shsToList (sshape arr)) +    in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr  instance Elt a => NFData (Shaped sh a) where    rnf (Shaped arr) = rnf arr @@ -150,7 +142,7 @@ instance Elt a => Elt (Shaped sh a) where    mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" -  mstrideTree (M_Shaped arr) = mstrideTree arr +  marrayStrides (M_Shaped arr) = marrayStrides arr    mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()    mvecsWrite sh idx (Shaped arr) vecs = diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs new file mode 100644 index 0000000..84c770a --- /dev/null +++ b/src/Data/Bag.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE DeriveTraversable #-} +module Data.Bag where + + +-- | An ordered sequence that can be folded over. +data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) | BList [Bag a] +  deriving (Functor, Foldable, Traversable) + +-- Really only here for 'pure' +instance Applicative Bag where +  pure = BOne +  BZero <*> _ = BZero +  BOne f <*> t = f <$> t +  BTwo f1 f2 <*> t = BTwo (f1 <*> t) (f2 <*> t) +  BList fs <*> t = BList [f <*> t | f <- fs] + +instance Semigroup (Bag a) where (<>) = BTwo +instance Monoid (Bag a) where mempty = BZero | 
