aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 13:03:53 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 13:03:53 +0200
commit70d2edeb338c6acbe9943c4f8b24225bcb912211 (patch)
tree3c6570d2a9f8eb9f89b16de092e21ad516b9349c /src/Data/Array/Nested
parent478875e9d82e8c645cbea2e41362c312e892488a (diff)
Num instances for Mixed, Ranked, Shaped
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs72
1 files changed, 72 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 86b0fce..f7c383a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where
import Control.Monad (forM_)
import Control.Monad.ST
+import qualified Data.Array.RankedS as S
import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
@@ -353,6 +354,33 @@ mtranspose perm =
mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
(X.transpose perm))
+mliftPrim :: (KnownShapeX sh, Storable a)
+ => (a -> a)
+ -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: (KnownShapeX sh, Storable a)
+ => (a -> a -> a)
+ -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) =
+ M_Primitive (X.XArray (S.zipWithA f arr1 arr2))
+
+instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where
+ (+) = mliftPrim2 (+)
+ (-) = mliftPrim2 (-)
+ (*) = mliftPrim2 (*)
+ negate = mliftPrim negate
+ abs = mliftPrim abs
+ signum = mliftPrim signum
+ fromInteger n =
+ case X.ssxToShape' (knownShapeX @sh) of
+ Just sh -> M_Primitive (X.constant sh (fromInteger n))
+ Nothing -> error "Data.Array.Nested.fromIntegral: \
+ \Unknown components in shape, use explicit replicate"
+
+deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
+deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'INat'.
@@ -582,6 +610,28 @@ coerceMixedXArray = coerce
-- ====== API OF RANKED ARRAYS ====== --
+arithPromoteRanked :: forall n a. KnownINat n
+ => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a
+arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+arithPromoteRanked2 :: forall n a. KnownINat n
+ => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+ (+) = arithPromoteRanked2 (+)
+ (-) = arithPromoteRanked2 (-)
+ (*) = arithPromoteRanked2 (*)
+ negate = arithPromoteRanked negate
+ abs = arithPromoteRanked abs
+ signum = arithPromoteRanked signum
+ fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n)
+
+deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+
-- | An index into a rank-typed array.
type IxR :: INat -> Type
data IxR n where
@@ -646,6 +696,28 @@ rtranspose perm (Ranked arr)
-- ====== API OF SHAPED ARRAYS ====== --
+arithPromoteShaped :: forall sh a. KnownShape sh
+ => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a
+arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+arithPromoteShaped2 :: forall sh a. KnownShape sh
+ => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
+ (+) = arithPromoteShaped2 (+)
+ (-) = arithPromoteShaped2 (-)
+ (*) = arithPromoteShaped2 (*)
+ negate = arithPromoteShaped negate
+ abs = arithPromoteShaped abs
+ signum = arithPromoteShaped signum
+ fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n)
+
+deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
+deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
+
-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers