diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 13 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 71 | 
3 files changed, 87 insertions, 3 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 17b0ab4..246f8fc 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -13,6 +13,8 @@  module Data.Array.Mixed where  import qualified Data.Array.RankedS as S +import qualified Data.Array.Ranked as ORB +import Data.Coerce  import Data.Kind  import Data.Proxy  import Data.Type.Equality @@ -347,3 +349,14 @@ sumOuter :: forall sh sh' a. (Storable a, Num a)  sumOuter ssh ssh'    | Refl <- lemAppNil @sh    = sumInner ssh' ssh . transpose2 ssh ssh' + +fromList :: forall n sh a. Storable a +         => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromList ssh l +  | Dict <- lemKnownINatRankSSX ssh +  , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) +  = case ssh of +      m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> +        error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ +                "does not match the type (" ++ show (natVal m) ++ ")"  +      _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index cd2dde7..9219a74 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -5,7 +5,7 @@ module Data.Array.Nested (    IxR(..),    rshape, rindex, rindexPartial, rgenerate, rsumOuter1,    rtranspose, rappend, rscalar, rfromVector, runScalar, -  rconstant, +  rconstant, rfromList,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -15,7 +15,7 @@ module Data.Array.Nested (    KnownShape(..), SShape(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, sunScalar, -  sconstant, +  sconstant, sfromList,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -27,7 +27,7 @@ module Data.Array.Nested (    mconstant,    -- * Array elements -  Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), +  Elt(mshape, mindex, mindexPartial, mscalar, mfromList, mlift, mlift2),    Primitive(..),    -- * Inductive natural numbers diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f769870..759094e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -37,7 +37,9 @@ import Control.Monad.ST  import qualified Data.Array.RankedS as S  import Data.Bifunctor (first)  import Data.Coerce (coerce, Coercible) +import Data.Foldable (toList)  import Data.Kind +import Data.List.NonEmpty (NonEmpty)  import Data.Proxy  import Data.Type.Equality  import qualified Data.Vector.Storable as VS @@ -50,6 +52,29 @@ import qualified Data.Array.Mixed as X  import Data.INat +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +--   arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +--   rshape arr == 0 ::: 0 ::: 0 ::: IZR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +--   -> no, because mlift and also any kind of internals probing from outsiders + +  type family Replicate n a where    Replicate Z a = '[]    Replicate (S n) a = a : Replicate n a @@ -105,6 +130,9 @@ newtype Primitive a = Primitive a  -- class.  type Mixed :: [Maybe Nat] -> Type -> Type  data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, an empty array may have +-- elements with nonempty sizes, but then the whole array is still empty.  newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)    deriving (Show) @@ -167,6 +195,17 @@ class Elt a where    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a    mscalar :: a -> Mixed '[] a +  -- | All arrays in the list, even subarrays inside @a@, must have the same +  -- shape; if they do not, a runtime error will be thrown. See the +  -- documentation of 'mgenerate' for more information about this restriction. +  -- Furthermore, the length of the list must correspond with @n@: if @n@ is +  -- @Just m@ and @m@ does not equal the length of the list, a runtime error is +  -- thrown. +  -- +  -- If you want a single-dimensional array from your list, map 'mscalar' +  -- first. +  mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 a -> Mixed sh2 a @@ -215,6 +254,7 @@ instance Storable a => Elt (Primitive a) where    mindex (M_Primitive a) i = Primitive (X.index a i)    mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)    mscalar (Primitive x) = M_Primitive (X.scalar x) +  mfromList l = M_Primitive (X.fromList knownShapeX [x | M_Primitive x <- toList l])    mlift :: forall sh1 sh2.             (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -263,6 +303,8 @@ instance (Elt a, Elt b) => Elt (a, b) where    mindex (M_Tup2 a b) i = (mindex a i, mindex b i)    mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)    mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) +  mfromList l = M_Tup2 (mfromList ((\(M_Tup2 x _) -> x) <$> l)) +                       (mfromList ((\(M_Tup2 _ y) -> y) <$> l))    mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)    mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -302,6 +344,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mscalar x = M_Nest x +  mfromList :: forall n sh. KnownShapeX (n : sh) +            => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) +  mfromList l +    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) +    = M_Nest (mfromList ((\(M_Nest x) -> x) <$> l)) +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)          -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -546,6 +594,12 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where    mscalar (Ranked x) = M_Ranked (M_Nest x) +  mfromList :: forall m sh. KnownShapeX (m : sh) +            => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) +  mfromList l +    | Dict <- lemKnownReplicate (Proxy @n) +    = M_Ranked (mfromList ((\(M_Ranked x) -> x) <$> l)) +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -664,6 +718,12 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where    mscalar (Shaped x) = M_Shaped (M_Nest x) +  mfromList :: forall n sh'. KnownShapeX (n : sh') +            => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) +  mfromList l +    | Dict <- lemKnownMapJust (Proxy @sh) +    = M_Shaped (mfromList ((\(M_Shaped x) -> x) <$> l)) +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -870,6 +930,11 @@ rconstant sh x    | Dict <- lemKnownReplicate (Proxy @n)    = Ranked (mconstant (ixCvtRX sh) x) +rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a +rfromList l +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mfromList ((\(Ranked x) -> x) <$> l)) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -993,3 +1058,9 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust  sconstant x    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) + +sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) +          => NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromList l +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mfromList ((\(Shaped x) -> x) <$> l)) | 
