diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 62 | 
1 files changed, 33 insertions, 29 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 759094e..f76b2ab 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -206,10 +206,16 @@ class Elt a where    -- first.    mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a +  -- | Note: this library makes no particular guarantees about the shapes of +  -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the +  -- full 'XArray' and as such you can distinguish different empty arrays by +  -- the "shapes" of their elements. This information is meaningless, so you +  -- should not use it.    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 a -> Mixed sh2 a +  -- | See the documentation for 'mlift'.    mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)           => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)           -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a @@ -421,7 +427,6 @@ checkBounds IZX SZX = True  checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'  checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' --- Public method. Turns out this doesn't have to be in the type class!  -- | Create an array given a size and a function that computes the element at a  -- given index.  -- @@ -487,30 +492,18 @@ mfromVector sh v  munScalar :: Elt a => Mixed '[] a -> a  munScalar arr = mindex arr IZX -mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) -          => IxX sh -> a -> Mixed sh a -mconstant sh x +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x    | not (checkBounds sh (knownShapeX @sh)) =        error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)    | otherwise = -      coerce (M_Primitive (X.constant sh x)) +      M_Primitive (X.constant sh x) --- | All arrays in the list must have the same shape; if they do not, a runtime --- error will be thrown. See the documentation of 'mgenerate' for more --- information. Furthermore, the length of the list must correspond with @n@: --- if @n@ is @Just m@ and @m@ does not equal the length of the list, a runtime --- error is thrown. --- mfromList :: forall n sh a. (KnownShapeX (n : sh), Elt a) => [Mixed sh a] -> Mixed (n : sh) a --- mfromList l = ---   case knownShapeX @(n : sh) of ---     m@GHC_SNat :$@ _ ---       | length l /= fromIntegral (natVal m) -> ---           error $ "mfromList: length of list (" ++ show (length l) ++ ")" ++ ---                   "does not match the type (" ++ show (natVal m) ++ ")" ---       | natVal m == 0 -> memptyArray _ ---       -- | let shapetree = mshapeTree ---       | otherwise -> _ ---     () :$? _ -> _ +-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' +-- is defined. +mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) +          => IxX sh -> a -> Mixed sh a +mconstant sh x = coerce (mconstantP sh x)  mliftPrim :: (KnownShapeX sh, Storable a)            => (a -> a) @@ -826,7 +819,10 @@ instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where    negate = arithPromoteRanked negate    abs = arithPromoteRanked abs    signum = arithPromoteRanked signum -  fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) +  fromInteger n = case inatSing @n of +                    SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) +                    SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ +                                  \Rank non-zero, use explicit mconstant"  deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)  deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) @@ -886,6 +882,7 @@ rgenerate sh f    , Refl <- lemRankReplicate (Proxy @n)    = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) +-- | See the documentation of 'mlift'.  rlift :: forall n1 n2 a. (KnownINat n2, Elt a)        => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)        -> Ranked n1 a -> Ranked n2 a @@ -924,11 +921,14 @@ rfromVector sh v  runScalar :: Elt a => Ranked I0 a -> a  runScalar arr = rindex arr IZR +rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a) +rconstantP sh x +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mconstantP (ixCvtRX sh) x) +  rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))            => IxR n -> a -> Ranked n a -rconstant sh x -  | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mconstant (ixCvtRX sh) x) +rconstant sh x = coerce (rconstantP sh x)  rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a  rfromList l @@ -955,7 +955,7 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe    negate = arithPromoteShaped negate    abs = arithPromoteShaped abs    signum = arithPromoteShaped signum -  fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) +  fromInteger n = sconstantP (fromInteger n)  deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)  deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) @@ -1015,6 +1015,7 @@ sgenerate f    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) +-- | See the documentation of 'mlift'.  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)        => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a @@ -1053,11 +1054,14 @@ sfromVector v  sunScalar :: Elt a => Shaped '[] a -> a  sunScalar arr = sindex arr IZS +sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) +sconstantP x +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mconstantP (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) +  sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))            => a -> Shaped sh a -sconstant x -  | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) +sconstant x = coerce (sconstantP @sh x)  sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)            => NonEmpty (Shaped sh a) -> Shaped (n : sh) a | 
