From 6fc6f4327391f14f026a9848f68a28e70cef6185 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 20 Apr 2024 20:11:41 +0200 Subject: Some cleanups --- src/Data/Array/Nested/Internal.hs | 64 +++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 30 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 759094e..f76b2ab 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -206,10 +206,16 @@ class Elt a where -- first. mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a + -- | See the documentation for 'mlift'. mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a @@ -421,7 +427,6 @@ checkBounds IZX SZX = True checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' --- Public method. Turns out this doesn't have to be in the type class! -- | Create an array given a size and a function that computes the element at a -- given index. -- @@ -487,30 +492,18 @@ mfromVector sh v munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr IZX -mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) - => IxX sh -> a -> Mixed sh a -mconstant sh x +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x | not (checkBounds sh (knownShapeX @sh)) = error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) | otherwise = - coerce (M_Primitive (X.constant sh x)) - --- | All arrays in the list must have the same shape; if they do not, a runtime --- error will be thrown. See the documentation of 'mgenerate' for more --- information. Furthermore, the length of the list must correspond with @n@: --- if @n@ is @Just m@ and @m@ does not equal the length of the list, a runtime --- error is thrown. --- mfromList :: forall n sh a. (KnownShapeX (n : sh), Elt a) => [Mixed sh a] -> Mixed (n : sh) a --- mfromList l = --- case knownShapeX @(n : sh) of --- m@GHC_SNat :$@ _ --- | length l /= fromIntegral (natVal m) -> --- error $ "mfromList: length of list (" ++ show (length l) ++ ")" ++ --- "does not match the type (" ++ show (natVal m) ++ ")" --- | natVal m == 0 -> memptyArray _ --- -- | let shapetree = mshapeTree --- | otherwise -> _ --- () :$? _ -> _ + M_Primitive (X.constant sh x) + +-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' +-- is defined. +mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) + => IxX sh -> a -> Mixed sh a +mconstant sh x = coerce (mconstantP sh x) mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) @@ -826,7 +819,10 @@ instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) + fromInteger n = case inatSing @n of + SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) + SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ + \Rank non-zero, use explicit mconstant" deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) @@ -886,6 +882,7 @@ rgenerate sh f , Refl <- lemRankReplicate (Proxy @n) = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) +-- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. (KnownINat n2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a @@ -924,11 +921,14 @@ rfromVector sh v runScalar :: Elt a => Ranked I0 a -> a runScalar arr = rindex arr IZR +rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a) +rconstantP sh x + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mconstantP (ixCvtRX sh) x) + rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) => IxR n -> a -> Ranked n a -rconstant sh x - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mconstant (ixCvtRX sh) x) +rconstant sh x = coerce (rconstantP sh x) rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a rfromList l @@ -955,7 +955,7 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) + fromInteger n = sconstantP (fromInteger n) deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) @@ -1015,6 +1015,7 @@ sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) +-- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a @@ -1053,11 +1054,14 @@ sfromVector v sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr IZS +sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) +sconstantP x + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mconstantP (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) + sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) => a -> Shaped sh a -sconstant x - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) +sconstant x = coerce (sconstantP @sh x) sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) => NonEmpty (Shaped sh a) -> Shaped (n : sh) a -- cgit v1.2.3-70-g09d2