diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-15 23:30:23 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-15 23:30:23 +0200 |
commit | a95ae379851158f48f90a0274ad74caa44f582e0 (patch) | |
tree | 0b0c15034520267f00b22c708715aa28c8f5ceda /src/Data/Array/Nested | |
parent | 690a74d571c61330978fdf5e4565ce0b8622030b (diff) |
Proper checking in *generate, plus warning in haddocks
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 129 |
1 files changed, 98 insertions, 31 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 4764165..376651f 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -32,9 +32,10 @@ module Data.Array.Nested.Internal where import Prelude hiding (mappend) -import Control.Monad (forM_) +import Control.Monad (forM_, when) import Control.Monad.ST import qualified Data.Array.RankedS as S +import Data.Bifunctor (first) import Data.Coerce (coerce, Coercible) import Data.Kind import Data.Proxy @@ -79,6 +80,11 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n) go SZ = Refl go (SS n) | Refl <- go n = Refl +ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh') +ixAppSplit _ SZX idx = (IZX, idx) +ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx) +ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx) + -- | Wrapper type used as a tag to attach instances on. The instances on arrays -- of @'Primitive' a@ are more polymorphic than the direct instances for arrays @@ -137,6 +143,19 @@ data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) +-- | Tree giving the shape of every array component. +type family ShapeTree a where + ShapeTree (Primitive _) = () + ShapeTree Int = () + ShapeTree Double = () + ShapeTree () = () + + ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a) + ShapeTree (Ranked n a) = (IxR n, ShapeTree a) + ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a) + + -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. @@ -162,11 +181,11 @@ class Elt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArray :: IxX sh -> Mixed sh a - -- | Return the size of the individual (SoA) arrays in this value. If @a@ - -- does not contain tuples, this coincides with the total number of scalars - -- in the given value; if @a@ contains tuples, then it is some multiple of - -- this number of scalars. - mvecsNumElts :: a -> Int + mshapeTree :: a -> ShapeTree a + + mshapeTreeZero :: Proxy a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool -- | Create uninitialised vectors for this array type, given the shape of -- this vector and an example for the contents. The shape must not have size @@ -210,7 +229,9 @@ instance Storable a => Elt (Primitive a) where = M_Primitive (f Proxy a b) memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) - mvecsNumElts _ = 1 + mshapeTree _ = () + mshapeTreeZero _ = () + mshapeTreeEq _ () () = True mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x @@ -219,7 +240,7 @@ instance Storable a => Elt (Primitive a) where :: forall sh' sh s. KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) + let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v @@ -238,7 +259,9 @@ instance (Elt a, Elt b) => Elt (a, b) where mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b)) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a @@ -250,15 +273,13 @@ instance (Elt a, Elt b) => Elt (a, b) where -- Arrays of arrays are just arrays, but with more dimensions. instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = ixAppPrefix (knownShapeX @sh) (mshape arr) - where - ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 - ixAppPrefix SZX _ = IZX - ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx - ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx + = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) mindex (M_Nest arr) i = mindexPartial arr i @@ -300,16 +321,18 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) = f (Proxy @(sh' ++ shT)) - memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) + memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) + + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) - mvecsNumElts arr = - let n = X.shapeSize (mshape arr) - in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) + mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 mvecsUnsafeNew sh example | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) - (mindex example (X.zeroIdx (knownShapeX @sh'))) + (mindex example (X.zeroIxX (knownShapeX @sh'))) where sh' = mshape example @@ -336,6 +359,21 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' -- Public method. Turns out this doesn't have to be in the type class! -- | Create an array given a size and a function that computes the element at a -- given index. +-- +-- **WARNING**: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- foo = mgenerate (10 ::: IZR) $ \(i ::: IZR) -> +-- mgenerate (i ::: IZR) $ \(j ::: IZR) -> +-- ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a mgenerate sh f -- TODO: Do we need this checkBounds check elsewhere as well? @@ -345,17 +383,21 @@ mgenerate sh f -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. | X.shapeSize sh == 0 = memptyArray sh | otherwise = - let firstidx = X.zeroIdx' sh - firstelem = f (X.zeroIdx' sh) - in if mvecsNumElts firstelem == 0 + let firstidx = X.zeroIxX' sh + firstelem = f (X.zeroIxX' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEq (Proxy @a) shapetree (mshapeTreeZero (Proxy @a)) then memptyArray sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this feels inefficient. Should improve this. - forM_ (tail (X.enumShape sh)) $ \idx -> - mvecsWrite sh idx (f idx) vecs + -- scalar this array copying inefficient. Should improve this. + forM_ (tail (X.enumShape sh)) $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs mvecsFreeze sh vecs mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -492,9 +534,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ memptyArray i - mvecsNumElts (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsNumElts arr + mshapeTree (Ranked arr) + | Refl <- lemRankReplicate (Proxy @n) + , Dict <- lemKnownReplicate (Proxy @n) + = first ixCvtXR (mshapeTree arr) + + mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 mvecsUnsafeNew idx (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) @@ -597,9 +644,13 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ memptyArray i - mvecsNumElts (Shaped arr) + mshapeTree (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsNumElts arr + = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) + + mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 mvecsUnsafeNew idx (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) @@ -675,8 +726,14 @@ type IxR :: INat -> Type data IxR n where IZR :: IxR Z (:::) :: Int -> IxR n -> IxR (S n) +deriving instance Show (IxR n) +deriving instance Eq (IxR n) infixr 5 ::: +zeroIxR :: SINat n -> IxR n +zeroIxR SZ = IZR +zeroIxR (SS n) = 0 ::: zeroIxR n + ixCvtXR :: IxX sh -> IxR (X.Rank sh) ixCvtXR IZX = IZR ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx @@ -702,6 +759,8 @@ rindexPartial (Ranked arr) idx = (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) (ixCvtRX idx)) +-- | **WARNING**: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a rgenerate sh f | Dict <- lemKnownReplicate (Proxy @n) @@ -787,8 +846,14 @@ type IxS :: [Nat] -> Type data IxS sh where IZS :: IxS '[] (::$) :: Int -> IxS sh -> IxS (n : sh) +deriving instance Show (IxS n) +deriving instance Eq (IxS n) infixr 5 ::$ +zeroIxS :: SShape sh -> IxS sh +zeroIxS ShNil = IZS +zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh + cvtSShapeIxS :: SShape sh -> IxS sh cvtSShapeIxS ShNil = IZS cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh @@ -815,6 +880,8 @@ sindexPartial (Shaped arr) idx = (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) +-- | **WARNING**: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) |