diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:03:53 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:03:53 +0200 |
commit | 70d2edeb338c6acbe9943c4f8b24225bcb912211 (patch) | |
tree | 3c6570d2a9f8eb9f89b16de092e21ad516b9349c /src/Data/Array/Nested | |
parent | 478875e9d82e8c645cbea2e41362c312e892488a (diff) |
Num instances for Mixed, Ranked, Shaped
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 86b0fce..f7c383a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where import Control.Monad (forM_) import Control.Monad.ST +import qualified Data.Array.RankedS as S import Data.Coerce (coerce, Coercible) import Data.Kind import Data.Proxy @@ -353,6 +354,33 @@ mtranspose perm = mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') (X.transpose perm)) +mliftPrim :: (KnownShapeX sh, Storable a) + => (a -> a) + -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (KnownShapeX sh, Storable a) + => (a -> a -> a) + -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = + M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) + +instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where + (+) = mliftPrim2 (+) + (-) = mliftPrim2 (-) + (*) = mliftPrim2 (*) + negate = mliftPrim negate + abs = mliftPrim abs + signum = mliftPrim signum + fromInteger n = + case X.ssxToShape' (knownShapeX @sh) of + Just sh -> M_Primitive (X.constant sh (fromInteger n)) + Nothing -> error "Data.Array.Nested.fromIntegral: \ + \Unknown components in shape, use explicit replicate" + +deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) +deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) + -- | A rank-typed array: the number of dimensions of the array (its /rank/) is -- represented on the type level as a 'INat'. @@ -582,6 +610,28 @@ coerceMixedXArray = coerce -- ====== API OF RANKED ARRAYS ====== -- +arithPromoteRanked :: forall n a. KnownINat n + => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a +arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce + +arithPromoteRanked2 :: forall n a. KnownINat n + => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce + +instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where + (+) = arithPromoteRanked2 (+) + (-) = arithPromoteRanked2 (-) + (*) = arithPromoteRanked2 (*) + negate = arithPromoteRanked negate + abs = arithPromoteRanked abs + signum = arithPromoteRanked signum + fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) + +deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) + -- | An index into a rank-typed array. type IxR :: INat -> Type data IxR n where @@ -646,6 +696,28 @@ rtranspose perm (Ranked arr) -- ====== API OF SHAPED ARRAYS ====== -- +arithPromoteShaped :: forall sh a. KnownShape sh + => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a +arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +arithPromoteShaped2 :: forall sh a. KnownShape sh + => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where + (+) = arithPromoteShaped2 (+) + (-) = arithPromoteShaped2 (-) + (*) = arithPromoteShaped2 (*) + negate = arithPromoteShaped negate + abs = arithPromoteShaped abs + signum = arithPromoteShaped signum + fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) + +deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) +deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) + -- | An index into a shape-typed array. -- -- For convenience, this contains regular 'Int's instead of bounded integers |