From 70d2edeb338c6acbe9943c4f8b24225bcb912211 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 13:03:53 +0200 Subject: Num instances for Mixed, Ranked, Shaped --- src/Data/Array/Mixed.hs | 12 +++++++ src/Data/Array/Nested/Internal.hs | 72 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) (limited to 'src') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 7f25d84..d9eb5f0 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -111,6 +111,12 @@ shapeSize IZX = 1 shapeSize (n ::@ sh) = n * shapeSize sh shapeSize (n ::? sh) = n * shapeSize sh +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh) +ssxToShape' SZX = Just IZX +ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh +ssxToShape' (_ :$? _) = Nothing + fromLinearIdx :: IxX sh -> Int -> IxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -221,6 +227,12 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a +constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a +constant sh x + | Dict <- lemKnownINatRank sh + , Dict <- knownNatFromINat (Proxy @(Rank sh)) + = XArray (S.constant (shapeLshape sh) x) + generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 86b0fce..f7c383a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where import Control.Monad (forM_) import Control.Monad.ST +import qualified Data.Array.RankedS as S import Data.Coerce (coerce, Coercible) import Data.Kind import Data.Proxy @@ -353,6 +354,33 @@ mtranspose perm = mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') (X.transpose perm)) +mliftPrim :: (KnownShapeX sh, Storable a) + => (a -> a) + -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (KnownShapeX sh, Storable a) + => (a -> a -> a) + -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = + M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) + +instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where + (+) = mliftPrim2 (+) + (-) = mliftPrim2 (-) + (*) = mliftPrim2 (*) + negate = mliftPrim negate + abs = mliftPrim abs + signum = mliftPrim signum + fromInteger n = + case X.ssxToShape' (knownShapeX @sh) of + Just sh -> M_Primitive (X.constant sh (fromInteger n)) + Nothing -> error "Data.Array.Nested.fromIntegral: \ + \Unknown components in shape, use explicit replicate" + +deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) +deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) + -- | A rank-typed array: the number of dimensions of the array (its /rank/) is -- represented on the type level as a 'INat'. @@ -582,6 +610,28 @@ coerceMixedXArray = coerce -- ====== API OF RANKED ARRAYS ====== -- +arithPromoteRanked :: forall n a. KnownINat n + => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a +arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce + +arithPromoteRanked2 :: forall n a. KnownINat n + => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce + +instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where + (+) = arithPromoteRanked2 (+) + (-) = arithPromoteRanked2 (-) + (*) = arithPromoteRanked2 (*) + negate = arithPromoteRanked negate + abs = arithPromoteRanked abs + signum = arithPromoteRanked signum + fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) + +deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) + -- | An index into a rank-typed array. type IxR :: INat -> Type data IxR n where @@ -646,6 +696,28 @@ rtranspose perm (Ranked arr) -- ====== API OF SHAPED ARRAYS ====== -- +arithPromoteShaped :: forall sh a. KnownShape sh + => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a +arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +arithPromoteShaped2 :: forall sh a. KnownShape sh + => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where + (+) = arithPromoteShaped2 (+) + (-) = arithPromoteShaped2 (-) + (*) = arithPromoteShaped2 (*) + negate = arithPromoteShaped negate + abs = arithPromoteShaped abs + signum = arithPromoteShaped signum + fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) + +deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) +deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) + -- | An index into a shape-typed array. -- -- For convenience, this contains regular 'Int's instead of bounded integers -- cgit v1.2.3-70-g09d2