diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 17 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 129 | 
2 files changed, 107 insertions, 39 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 9a8ccfd..17b0ab4 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -49,6 +49,7 @@ data IxX sh where    (::@) :: Int -> IxX sh -> IxX (Just n : sh)    (::?) :: Int -> IxX sh -> IxX (Nothing : sh)  deriving instance Show (IxX sh) +deriving instance Eq (IxX sh)  infixr 5 ::@  infixr 5 ::? @@ -81,15 +82,15 @@ type XArray :: [Maybe Nat] -> Type -> Type  newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)    deriving (Show) -zeroIdx :: StaticShapeX sh -> IxX sh -zeroIdx SZX = IZX -zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh -zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh +zeroIxX :: StaticShapeX sh -> IxX sh +zeroIxX SZX = IZX +zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh +zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh -zeroIdx' :: IxX sh -> IxX sh -zeroIdx' IZX = IZX -zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh -zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh +zeroIxX' :: IxX sh -> IxX sh +zeroIxX' IZX = IZX +zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh +zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh  ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')  ixAppend IZX idx' = idx' diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 4764165..376651f 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -32,9 +32,10 @@ module Data.Array.Nested.Internal where  import Prelude hiding (mappend) -import Control.Monad (forM_) +import Control.Monad (forM_, when)  import Control.Monad.ST  import qualified Data.Array.RankedS as S +import Data.Bifunctor (first)  import Data.Coerce (coerce, Coercible)  import Data.Kind  import Data.Proxy @@ -79,6 +80,11 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)      go SZ = Refl      go (SS n) | Refl <- go n = Refl +ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh') +ixAppSplit _ SZX idx = (IZX, idx) +ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx) +ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx) +  -- | Wrapper type used as a tag to attach instances on. The instances on arrays  -- of @'Primitive' a@ are more polymorphic than the direct instances for arrays @@ -137,6 +143,19 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh  data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) +-- | Tree giving the shape of every array component. +type family ShapeTree a where +  ShapeTree (Primitive _) = () +  ShapeTree Int = () +  ShapeTree Double = () +  ShapeTree () = () + +  ShapeTree (a, b) = (ShapeTree a, ShapeTree b) +  ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a) +  ShapeTree (Ranked n a) = (IxR n, ShapeTree a) +  ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a) + +  -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or  -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'  -- a@; see the documentation for 'Primitive' for more details. @@ -162,11 +181,11 @@ class Elt a where    -- | Create an empty array. The given shape must have size zero; this may or may not be checked.    memptyArray :: IxX sh -> Mixed sh a -  -- | Return the size of the individual (SoA) arrays in this value. If @a@ -  -- does not contain tuples, this coincides with the total number of scalars -  -- in the given value; if @a@ contains tuples, then it is some multiple of -  -- this number of scalars. -  mvecsNumElts :: a -> Int +  mshapeTree :: a -> ShapeTree a + +  mshapeTreeZero :: Proxy a -> ShapeTree a + +  mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool    -- | Create uninitialised vectors for this array type, given the shape of    -- this vector and an example for the contents. The shape must not have size @@ -210,7 +229,9 @@ instance Storable a => Elt (Primitive a) where      = M_Primitive (f Proxy a b)    memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) -  mvecsNumElts _ = 1 +  mshapeTree _ = () +  mshapeTreeZero _ = () +  mshapeTreeEq _ () () = True    mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)    mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x @@ -219,7 +240,7 @@ instance Storable a => Elt (Primitive a) where      :: forall sh' sh s. KnownShapeX sh'      => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do -    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) +    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))      VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)    mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v @@ -238,7 +259,9 @@ instance (Elt a, Elt b) => Elt (a, b) where    mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)    memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) -  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y +  mshapeTree (x, y) = (mshapeTree x, mshapeTree y) +  mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b)) +  mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'    mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y    mvecsWrite sh i (x, y) (MV_Tup2 a b) = do      mvecsWrite sh i x a @@ -250,15 +273,13 @@ instance (Elt a, Elt b) => Elt (a, b) where  -- Arrays of arrays are just arrays, but with more dimensions.  instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where +  -- TODO: this is quadratic in the nesting depth because it repeatedly +  -- truncates the shape vector to one a little shorter. Fix with a +  -- moverlongShape method, a prefix of which is mshape.    mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh    mshape (M_Nest arr)      | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') -    = ixAppPrefix (knownShapeX @sh) (mshape arr) -    where -      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 -      ixAppPrefix SZX _ = IZX -      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx -      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx +    = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))    mindex (M_Nest arr) i = mindexPartial arr i @@ -300,16 +321,18 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where          , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))          = f (Proxy @(sh' ++ shT)) -  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) +  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) + +  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) -  mvecsNumElts arr = -    let n = X.shapeSize (mshape arr) -    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) +  mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2    mvecsUnsafeNew sh example      | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"      | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) -                                                 (mindex example (X.zeroIdx (knownShapeX @sh'))) +                                                 (mindex example (X.zeroIxX (knownShapeX @sh')))      where        sh' = mshape example @@ -336,6 +359,21 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'  -- Public method. Turns out this doesn't have to be in the type class!  -- | Create an array given a size and a function that computes the element at a  -- given index. +-- +-- **WARNING**: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +--   foo :: Mixed [Nothing] (Mixed [Nothing] Double) +--   foo = mgenerate (10 ::: IZR) $ \(i ::: IZR) -> +--           mgenerate (i ::: IZR) $ \(j ::: IZR) -> +--             ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check.  mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a  mgenerate sh f    -- TODO: Do we need this checkBounds check elsewhere as well? @@ -345,17 +383,21 @@ mgenerate sh f    -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.    | X.shapeSize sh == 0 = memptyArray sh    | otherwise = -      let firstidx = X.zeroIdx' sh -          firstelem = f (X.zeroIdx' sh) -      in if mvecsNumElts firstelem == 0 +      let firstidx = X.zeroIxX' sh +          firstelem = f (X.zeroIxX' sh) +          shapetree = mshapeTree firstelem +      in if mshapeTreeEq (Proxy @a) shapetree (mshapeTreeZero (Proxy @a))             then memptyArray sh             else runST $ do                    vecs <- mvecsUnsafeNew sh firstelem                    mvecsWrite sh firstidx firstelem vecs                    -- TODO: This is likely fine if @a@ is big, but if @a@ is a -                  -- scalar this feels inefficient. Should improve this. -                  forM_ (tail (X.enumShape sh)) $ \idx -> -                    mvecsWrite sh idx (f idx) vecs +                  -- scalar this array copying inefficient. Should improve this. +                  forM_ (tail (X.enumShape sh)) $ \idx -> do +                    let val = f idx +                    when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ +                      error "Data.Array.Nested mgenerate: generated values do not have equal shapes" +                    mvecsWrite sh idx val vecs                    mvecsFreeze sh vecs  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -492,9 +534,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $          memptyArray i -  mvecsNumElts (Ranked arr) -    | Dict <- lemKnownReplicate (Proxy @n) -    = mvecsNumElts arr +  mshapeTree (Ranked arr) +    | Refl <- lemRankReplicate (Proxy @n) +    , Dict <- lemKnownReplicate (Proxy @n) +    = first ixCvtXR (mshapeTree arr) + +  mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2    mvecsUnsafeNew idx (Ranked arr)      | Dict <- lemKnownReplicate (Proxy @n) @@ -597,9 +644,13 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $          memptyArray i -  mvecsNumElts (Shaped arr) +  mshapeTree (Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) -    = mvecsNumElts arr +    = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) + +  mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2    mvecsUnsafeNew idx (Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) @@ -675,8 +726,14 @@ type IxR :: INat -> Type  data IxR n where    IZR :: IxR Z    (:::) :: Int -> IxR n -> IxR (S n) +deriving instance Show (IxR n) +deriving instance Eq (IxR n)  infixr 5 ::: +zeroIxR :: SINat n -> IxR n +zeroIxR SZ = IZR +zeroIxR (SS n) = 0 ::: zeroIxR n +  ixCvtXR :: IxX sh -> IxR (X.Rank sh)  ixCvtXR IZX = IZR  ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx @@ -702,6 +759,8 @@ rindexPartial (Ranked arr) idx =              (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)              (ixCvtRX idx)) +-- | **WARNING**: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details.  rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a  rgenerate sh f    | Dict <- lemKnownReplicate (Proxy @n) @@ -787,8 +846,14 @@ type IxS :: [Nat] -> Type  data IxS sh where    IZS :: IxS '[]    (::$) :: Int -> IxS sh -> IxS (n : sh) +deriving instance Show (IxS n) +deriving instance Eq (IxS n)  infixr 5 ::$ +zeroIxS :: SShape sh -> IxS sh +zeroIxS ShNil = IZS +zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh +  cvtSShapeIxS :: SShape sh -> IxS sh  cvtSShapeIxS ShNil = IZS  cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh @@ -815,6 +880,8 @@ sindexPartial (Shaped arr) idx =              (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)              (ixCvtSX idx)) +-- | **WARNING**: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details.  sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a  sgenerate f    | Dict <- lemKnownMapJust (Proxy @sh) | 
