summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 13:03:53 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 13:03:53 +0200
commit70d2edeb338c6acbe9943c4f8b24225bcb912211 (patch)
tree3c6570d2a9f8eb9f89b16de092e21ad516b9349c /src
parent478875e9d82e8c645cbea2e41362c312e892488a (diff)
Num instances for Mixed, Ranked, Shaped
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed.hs12
-rw-r--r--src/Data/Array/Nested/Internal.hs72
2 files changed, 84 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 7f25d84..d9eb5f0 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -111,6 +111,12 @@ shapeSize IZX = 1
shapeSize (n ::@ sh) = n * shapeSize sh
shapeSize (n ::? sh) = n * shapeSize sh
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh)
+ssxToShape' SZX = Just IZX
+ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
+ssxToShape' (_ :$? _) = Nothing
+
fromLinearIdx :: IxX sh -> Int -> IxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -221,6 +227,12 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
+constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a
+constant sh x
+ | Dict <- lemKnownINatRank sh
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ = XArray (S.constant (shapeLshape sh) x)
+
generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 86b0fce..f7c383a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where
import Control.Monad (forM_)
import Control.Monad.ST
+import qualified Data.Array.RankedS as S
import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
@@ -353,6 +354,33 @@ mtranspose perm =
mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
(X.transpose perm))
+mliftPrim :: (KnownShapeX sh, Storable a)
+ => (a -> a)
+ -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: (KnownShapeX sh, Storable a)
+ => (a -> a -> a)
+ -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) =
+ M_Primitive (X.XArray (S.zipWithA f arr1 arr2))
+
+instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where
+ (+) = mliftPrim2 (+)
+ (-) = mliftPrim2 (-)
+ (*) = mliftPrim2 (*)
+ negate = mliftPrim negate
+ abs = mliftPrim abs
+ signum = mliftPrim signum
+ fromInteger n =
+ case X.ssxToShape' (knownShapeX @sh) of
+ Just sh -> M_Primitive (X.constant sh (fromInteger n))
+ Nothing -> error "Data.Array.Nested.fromIntegral: \
+ \Unknown components in shape, use explicit replicate"
+
+deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
+deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'INat'.
@@ -582,6 +610,28 @@ coerceMixedXArray = coerce
-- ====== API OF RANKED ARRAYS ====== --
+arithPromoteRanked :: forall n a. KnownINat n
+ => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a
+arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+arithPromoteRanked2 :: forall n a. KnownINat n
+ => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+ (+) = arithPromoteRanked2 (+)
+ (-) = arithPromoteRanked2 (-)
+ (*) = arithPromoteRanked2 (*)
+ negate = arithPromoteRanked negate
+ abs = arithPromoteRanked abs
+ signum = arithPromoteRanked signum
+ fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n)
+
+deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+
-- | An index into a rank-typed array.
type IxR :: INat -> Type
data IxR n where
@@ -646,6 +696,28 @@ rtranspose perm (Ranked arr)
-- ====== API OF SHAPED ARRAYS ====== --
+arithPromoteShaped :: forall sh a. KnownShape sh
+ => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a
+arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+arithPromoteShaped2 :: forall sh a. KnownShape sh
+ => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
+ (+) = arithPromoteShaped2 (+)
+ (-) = arithPromoteShaped2 (-)
+ (*) = arithPromoteShaped2 (*)
+ negate = arithPromoteShaped negate
+ abs = arithPromoteShaped abs
+ signum = arithPromoteShaped signum
+ fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n)
+
+deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
+deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
+
-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers