diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 72 | 
2 files changed, 84 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 7f25d84..d9eb5f0 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -111,6 +111,12 @@ shapeSize IZX = 1  shapeSize (n ::@ sh) = n * shapeSize sh  shapeSize (n ::? sh) = n * shapeSize sh +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh) +ssxToShape' SZX = Just IZX +ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh +ssxToShape' (_ :$? _) = Nothing +  fromLinearIdx :: IxX sh -> Int -> IxX sh  fromLinearIdx = \sh i -> case go sh i of    (idx, 0) -> idx @@ -221,6 +227,12 @@ scalar = XArray . S.scalar  unScalar :: Storable a => XArray '[] a -> a  unScalar (XArray a) = S.unScalar a +constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a +constant sh x +  | Dict <- lemKnownINatRank sh +  , Dict <- knownNatFromINat (Proxy @(Rank sh)) +  = XArray (S.constant (shapeLshape sh) x) +  generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a  generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 86b0fce..f7c383a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where  import Control.Monad (forM_)  import Control.Monad.ST +import qualified Data.Array.RankedS as S  import Data.Coerce (coerce, Coercible)  import Data.Kind  import Data.Proxy @@ -353,6 +354,33 @@ mtranspose perm =    mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')                              (X.transpose perm)) +mliftPrim :: (KnownShapeX sh, Storable a) +          => (a -> a) +          -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (KnownShapeX sh, Storable a) +           => (a -> a -> a) +           -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) +mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = +  M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) + +instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where +  (+) = mliftPrim2 (+) +  (-) = mliftPrim2 (-) +  (*) = mliftPrim2 (*) +  negate = mliftPrim negate +  abs = mliftPrim abs +  signum = mliftPrim signum +  fromInteger n = +    case X.ssxToShape' (knownShapeX @sh) of +      Just sh -> M_Primitive (X.constant sh (fromInteger n)) +      Nothing -> error "Data.Array.Nested.fromIntegral: \ +                       \Unknown components in shape, use explicit replicate" + +deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) +deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) +  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is  -- represented on the type level as a 'INat'. @@ -582,6 +610,28 @@ coerceMixedXArray = coerce  -- ====== API OF RANKED ARRAYS ====== -- +arithPromoteRanked :: forall n a. KnownINat n +                   => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) +                   -> Ranked n a -> Ranked n a +arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce + +arithPromoteRanked2 :: forall n a. KnownINat n +                    => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) +                    -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce + +instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where +  (+) = arithPromoteRanked2 (+) +  (-) = arithPromoteRanked2 (-) +  (*) = arithPromoteRanked2 (*) +  negate = arithPromoteRanked negate +  abs = arithPromoteRanked abs +  signum = arithPromoteRanked signum +  fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n) + +deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) +  -- | An index into a rank-typed array.  type IxR :: INat -> Type  data IxR n where @@ -646,6 +696,28 @@ rtranspose perm (Ranked arr)  -- ====== API OF SHAPED ARRAYS ====== -- +arithPromoteShaped :: forall sh a. KnownShape sh +                   => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) +                   -> Shaped sh a -> Shaped sh a +arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +arithPromoteShaped2 :: forall sh a. KnownShape sh +                    => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) +                    -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce + +instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where +  (+) = arithPromoteShaped2 (+) +  (-) = arithPromoteShaped2 (-) +  (*) = arithPromoteShaped2 (*) +  negate = arithPromoteShaped negate +  abs = arithPromoteShaped abs +  signum = arithPromoteShaped signum +  fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n) + +deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) +deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) +  -- | An index into a shape-typed array.  --  -- For convenience, this contains regular 'Int's instead of bounded integers | 
