aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs129
1 files changed, 98 insertions, 31 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 4764165..376651f 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -32,9 +32,10 @@ module Data.Array.Nested.Internal where
import Prelude hiding (mappend)
-import Control.Monad (forM_)
+import Control.Monad (forM_, when)
import Control.Monad.ST
import qualified Data.Array.RankedS as S
+import Data.Bifunctor (first)
import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
@@ -79,6 +80,11 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)
go SZ = Refl
go (SS n) | Refl <- go n = Refl
+ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh')
+ixAppSplit _ SZX idx = (IZX, idx)
+ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx)
+ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx)
+
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
@@ -137,6 +143,19 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
+-- | Tree giving the shape of every array component.
+type family ShapeTree a where
+ ShapeTree (Primitive _) = ()
+ ShapeTree Int = ()
+ ShapeTree Double = ()
+ ShapeTree () = ()
+
+ ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+ ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a)
+ ShapeTree (Ranked n a) = (IxR n, ShapeTree a)
+ ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a)
+
+
-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
@@ -162,11 +181,11 @@ class Elt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArray :: IxX sh -> Mixed sh a
- -- | Return the size of the individual (SoA) arrays in this value. If @a@
- -- does not contain tuples, this coincides with the total number of scalars
- -- in the given value; if @a@ contains tuples, then it is some multiple of
- -- this number of scalars.
- mvecsNumElts :: a -> Int
+ mshapeTree :: a -> ShapeTree a
+
+ mshapeTreeZero :: Proxy a -> ShapeTree a
+
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents. The shape must not have size
@@ -210,7 +229,9 @@ instance Storable a => Elt (Primitive a) where
= M_Primitive (f Proxy a b)
memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
- mvecsNumElts _ = 1
+ mshapeTree _ = ()
+ mshapeTreeZero _ = ()
+ mshapeTreeEq _ () () = True
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
@@ -219,7 +240,7 @@ instance Storable a => Elt (Primitive a) where
:: forall sh' sh s. KnownShapeX sh'
=> IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
- let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
+ let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))
VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v
@@ -238,7 +259,9 @@ instance (Elt a, Elt b) => Elt (a, b) where
mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
- mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
+ mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+ mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b))
+ mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
@@ -250,15 +273,13 @@ instance (Elt a, Elt b) => Elt (a, b) where
-- Arrays of arrays are just arrays, but with more dimensions.
instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
- = ixAppPrefix (knownShapeX @sh) (mshape arr)
- where
- ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
- ixAppPrefix SZX _ = IZX
- ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
- ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
+ = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
mindex (M_Nest arr) i = mindexPartial arr i
@@ -300,16 +321,18 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
, Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
= f (Proxy @(sh' ++ shT))
- memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
+ memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh'))))
+
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh'))))
- mvecsNumElts arr =
- let n = X.shapeSize (mshape arr)
- in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
+ mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
mvecsUnsafeNew sh example
| X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
| otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
- (mindex example (X.zeroIdx (knownShapeX @sh')))
+ (mindex example (X.zeroIxX (knownShapeX @sh')))
where
sh' = mshape example
@@ -336,6 +359,21 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
-- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
+--
+-- **WARNING**: It is required that every @a@ returned by the argument to
+-- 'mgenerate' has the same shape. For example, the following will throw a
+-- runtime error:
+--
+-- foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- foo = mgenerate (10 ::: IZR) $ \(i ::: IZR) ->
+-- mgenerate (i ::: IZR) $ \(j ::: IZR) ->
+-- ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' allows this requirement to be broken very
+-- easily, hence the runtime check.
mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
@@ -345,17 +383,21 @@ mgenerate sh f
-- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
| X.shapeSize sh == 0 = memptyArray sh
| otherwise =
- let firstidx = X.zeroIdx' sh
- firstelem = f (X.zeroIdx' sh)
- in if mvecsNumElts firstelem == 0
+ let firstidx = X.zeroIxX' sh
+ firstelem = f (X.zeroIxX' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEq (Proxy @a) shapetree (mshapeTreeZero (Proxy @a))
then memptyArray sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
-- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this feels inefficient. Should improve this.
- forM_ (tail (X.enumShape sh)) $ \idx ->
- mvecsWrite sh idx (f idx) vecs
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ (tail (X.enumShape sh)) $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
@@ -492,9 +534,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
memptyArray i
- mvecsNumElts (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsNumElts arr
+ mshapeTree (Ranked arr)
+ | Refl <- lemRankReplicate (Proxy @n)
+ , Dict <- lemKnownReplicate (Proxy @n)
+ = first ixCvtXR (mshapeTree arr)
+
+ mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
@@ -597,9 +644,13 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
memptyArray i
- mvecsNumElts (Shaped arr)
+ mshapeTree (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsNumElts arr
+ = first (ixCvtXS (knownShape @sh)) (mshapeTree arr)
+
+ mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -675,8 +726,14 @@ type IxR :: INat -> Type
data IxR n where
IZR :: IxR Z
(:::) :: Int -> IxR n -> IxR (S n)
+deriving instance Show (IxR n)
+deriving instance Eq (IxR n)
infixr 5 :::
+zeroIxR :: SINat n -> IxR n
+zeroIxR SZ = IZR
+zeroIxR (SS n) = 0 ::: zeroIxR n
+
ixCvtXR :: IxX sh -> IxR (X.Rank sh)
ixCvtXR IZX = IZR
ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
@@ -702,6 +759,8 @@ rindexPartial (Ranked arr) idx =
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
(ixCvtRX idx))
+-- | **WARNING**: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f
| Dict <- lemKnownReplicate (Proxy @n)
@@ -787,8 +846,14 @@ type IxS :: [Nat] -> Type
data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
+deriving instance Show (IxS n)
+deriving instance Eq (IxS n)
infixr 5 ::$
+zeroIxS :: SShape sh -> IxS sh
+zeroIxS ShNil = IZS
+zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh
+
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh
@@ -815,6 +880,8 @@ sindexPartial (Shaped arr) idx =
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
+-- | **WARNING**: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a
sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)