diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:03:53 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:03:53 +0200 | 
| commit | 70d2edeb338c6acbe9943c4f8b24225bcb912211 (patch) | |
| tree | 3c6570d2a9f8eb9f89b16de092e21ad516b9349c /src/Data/Array/Nested | |
| parent | 478875e9d82e8c645cbea2e41362c312e892488a (diff) | |
Num instances for Mixed, Ranked, Shaped
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 72 | 
1 files changed, 72 insertions, 0 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 86b0fce..f7c383a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where  import Control.Monad (forM_)  import Control.Monad.ST +import qualified Data.Array.RankedS as S  import Data.Coerce (coerce, Coercible)  import Data.Kind  import Data.Proxy @@ -353,6 +354,33 @@ mtranspose perm =    mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')                              (X.transpose perm)) +mliftPrim :: (KnownShapeX sh, Storable a) +          => (a -> a) +          -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (KnownShapeX sh, Storable a) +           => (a -> a -> a) +           -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = +  M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) + +instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where +  (+) = mliftPrim2 (+) +  (-) = mliftPrim2 (-) +  (*) = mliftPrim2 (*) +  negate = mliftPrim negate +  abs = mliftPrim abs +  signum = mliftPrim signum +  fromInteger n = +    case X.ssxToShape' (knownShapeX @sh) of +      Just sh -> M_Primitive (X.constant sh (fromInteger n)) +      Nothing -> error "Data.Array.Nested.fromIntegral: \ +                       \Unknown components in shape, use explicit replicate" + +deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) +deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) +  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is  -- represented on the type level as a 'INat'. @@ -582,6 +610,28 @@ coerceMixedXArray = coerce  -- ====== API OF RANKED ARRAYS ====== -- +arithPromoteRanked :: forall n a. KnownINat n +                   => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) +                   -> Ranked n a -> Ranked n a +arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce + +arithPromoteRanked2 :: forall n a. KnownINat n +                    => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) +                    -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce + +instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where +  (+) = arithPromoteRanked2 (+) +  (-) = arithPromoteRanked2 (-) +  (*) = arithPromoteRanked2 (*) +  negate = arithPromoteRanked negate +  abs = arithPromoteRanked abs +  signum = arithPromoteRanked signum +  fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) + +deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) +  -- | An index into a rank-typed array.  type IxR :: INat -> Type  data IxR n where @@ -646,6 +696,28 @@ rtranspose perm (Ranked arr)  -- ====== API OF SHAPED ARRAYS ====== -- +arithPromoteShaped :: forall sh a. KnownShape sh +                   => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) +                   -> Shaped sh a -> Shaped sh a +arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +arithPromoteShaped2 :: forall sh a. KnownShape sh +                    => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) +                    -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where +  (+) = arithPromoteShaped2 (+) +  (-) = arithPromoteShaped2 (-) +  (*) = arithPromoteShaped2 (*) +  negate = arithPromoteShaped negate +  abs = arithPromoteShaped abs +  signum = arithPromoteShaped signum +  fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) + +deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) +deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) +  -- | An index into a shape-typed array.  --  -- For convenience, this contains regular 'Int's instead of bounded integers | 
