aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-20 20:11:41 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-20 22:08:07 +0200
commit6fc6f4327391f14f026a9848f68a28e70cef6185 (patch)
tree4a86cd3536565bc242f7048949c1210251206f89
parent18139715c7e11e7d3dbb2cf769f64c2a725832e2 (diff)
Some cleanups
-rw-r--r--src/Data/Array/Nested/Internal.hs64
1 files changed, 34 insertions, 30 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 759094e..f76b2ab 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -206,10 +206,16 @@ class Elt a where
-- first.
mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a
+ -- | See the documentation for 'mlift'.
mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
@@ -421,7 +427,6 @@ checkBounds IZX SZX = True
checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
--- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
--
@@ -487,30 +492,18 @@ mfromVector sh v
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr IZX
-mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
- => IxX sh -> a -> Mixed sh a
-mconstant sh x
+mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a)
+mconstantP sh x
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
| otherwise =
- coerce (M_Primitive (X.constant sh x))
-
--- | All arrays in the list must have the same shape; if they do not, a runtime
--- error will be thrown. See the documentation of 'mgenerate' for more
--- information. Furthermore, the length of the list must correspond with @n@:
--- if @n@ is @Just m@ and @m@ does not equal the length of the list, a runtime
--- error is thrown.
--- mfromList :: forall n sh a. (KnownShapeX (n : sh), Elt a) => [Mixed sh a] -> Mixed (n : sh) a
--- mfromList l =
--- case knownShapeX @(n : sh) of
--- m@GHC_SNat :$@ _
--- | length l /= fromIntegral (natVal m) ->
--- error $ "mfromList: length of list (" ++ show (length l) ++ ")" ++
--- "does not match the type (" ++ show (natVal m) ++ ")"
--- | natVal m == 0 -> memptyArray _
--- -- | let shapetree = mshapeTree
--- | otherwise -> _
--- () :$? _ -> _
+ M_Primitive (X.constant sh x)
+
+-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
+-- is defined.
+mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
+ => IxX sh -> a -> Mixed sh a
+mconstant sh x = coerce (mconstantP sh x)
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
@@ -826,7 +819,10 @@ instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n)
+ fromInteger n = case inatSing @n of
+ SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
+ SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
+ \Rank non-zero, use explicit mconstant"
deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
@@ -886,6 +882,7 @@ rgenerate sh f
, Refl <- lemRankReplicate (Proxy @n)
= Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
+-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
@@ -924,11 +921,14 @@ rfromVector sh v
runScalar :: Elt a => Ranked I0 a -> a
runScalar arr = rindex arr IZR
+rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a)
+rconstantP sh x
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mconstantP (ixCvtRX sh) x)
+
rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
=> IxR n -> a -> Ranked n a
-rconstant sh x
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mconstant (ixCvtRX sh) x)
+rconstant sh x = coerce (rconstantP sh x)
rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
rfromList l
@@ -955,7 +955,7 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n)
+ fromInteger n = sconstantP (fromInteger n)
deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
@@ -1015,6 +1015,7 @@ sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
+-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
@@ -1053,11 +1054,14 @@ sfromVector v
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr IZS
+sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
+sconstantP x
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mconstantP (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)
+
sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
=> a -> Shaped sh a
-sconstant x
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)
+sconstant x = coerce (sconstantP @sh x)
sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
=> NonEmpty (Shaped sh a) -> Shaped (n : sh) a