From 3d2e4a567668ea951e629834e6871a3f144c1b84 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 21 May 2024 11:42:17 +0200 Subject: Add Eq, Fractional, Floating instances --- src/Data/Array/Nested/Internal.hs | 110 ++++++++++++++++++++++++++++++++++---- 1 file changed, 100 insertions(+), 10 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 66b3130..b645f4a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -53,6 +53,7 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.C.Types (CInt(..)) import Foreign.Storable (Storable) +import qualified GHC.Float (log1p, expm1, log1pexp, log1mexp) import GHC.IsList (IsList) import qualified GHC.IsList as IsList import GHC.TypeLits @@ -460,16 +461,16 @@ data family Mixed sh a -- ostensibly not exist; the full array is still empty. data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Show) + deriving (Show, Eq) -- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show) -- no content, orthotope optimises this (via Vector) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq) -- no content, orthotope optimises this (via Vector) -- etc. data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) @@ -995,6 +996,35 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where signum = mliftPrim signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" +instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftPrim recip + (/) = mliftPrim2 (/) + +instance (Floating a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftPrim exp + log = mliftPrim log + sqrt = mliftPrim sqrt + (**) = mliftPrim2 (**) + logBase = mliftPrim2 logBase + sin = mliftPrim sin + cos = mliftPrim cos + tan = mliftPrim tan + asin = mliftPrim asin + acos = mliftPrim acos + atan = mliftPrim atan + sinh = mliftPrim sinh + cosh = mliftPrim cosh + tanh = mliftPrim tanh + asinh = mliftPrim asinh + acosh = mliftPrim acosh + atanh = mliftPrim atanh + log1p = mliftPrim GHC.Float.log1p + expm1 = mliftPrim GHC.Float.expm1 + log1pexp = mliftPrim GHC.Float.log1pexp + log1mexp = mliftPrim GHC.Float.log1mexp + mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a mtoRanked arr | Refl <- X.lemAppNil @sh @@ -1029,6 +1059,7 @@ mcastToShaped arr targetsh type Ranked :: Nat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -- | A shape-typed array: the full shape of the array (the sizes of its -- dimensions) is represented on the type level as a list of 'Nat's. Note that @@ -1042,6 +1073,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) type Shaped :: [Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -- just unwrap the newtype and defer to the general instance for nested arrays newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) @@ -1277,7 +1309,36 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" + +instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" + recip = arithPromoteRanked recip + (/) = arithPromoteRanked2 (/) + +instance (Floating a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" + exp = arithPromoteRanked exp + log = arithPromoteRanked log + sqrt = arithPromoteRanked sqrt + (**) = arithPromoteRanked2 (**) + logBase = arithPromoteRanked2 logBase + sin = arithPromoteRanked sin + cos = arithPromoteRanked cos + tan = arithPromoteRanked tan + asin = arithPromoteRanked asin + acos = arithPromoteRanked acos + atan = arithPromoteRanked atan + sinh = arithPromoteRanked sinh + cosh = arithPromoteRanked cosh + tanh = arithPromoteRanked tanh + asinh = arithPromoteRanked asinh + acosh = arithPromoteRanked acosh + atanh = arithPromoteRanked atanh + log1p = arithPromoteRanked GHC.Float.log1p + expm1 = arithPromoteRanked GHC.Float.expm1 + log1pexp = arithPromoteRanked GHC.Float.log1pexp + log1mexp = arithPromoteRanked GHC.Float.log1mexp zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR @@ -1548,7 +1609,36 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" + +instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" + recip = arithPromoteShaped recip + (/) = arithPromoteShaped2 (/) + +instance (Floating a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" + exp = arithPromoteShaped exp + log = arithPromoteShaped log + sqrt = arithPromoteShaped sqrt + (**) = arithPromoteShaped2 (**) + logBase = arithPromoteShaped2 logBase + sin = arithPromoteShaped sin + cos = arithPromoteShaped cos + tan = arithPromoteShaped tan + asin = arithPromoteShaped asin + acos = arithPromoteShaped acos + atan = arithPromoteShaped atan + sinh = arithPromoteShaped sinh + cosh = arithPromoteShaped cosh + tanh = arithPromoteShaped tanh + asinh = arithPromoteShaped asinh + acosh = arithPromoteShaped acosh + atanh = arithPromoteShaped atanh + log1p = arithPromoteShaped GHC.Float.log1p + expm1 = arithPromoteShaped GHC.Float.expm1 + log1pexp = arithPromoteShaped GHC.Float.log1pexp + log1mexp = arithPromoteShaped GHC.Float.log1mexp zeroIxS :: ShS sh -> IIxS sh zeroIxS ZSS = ZIS -- cgit v1.2.3-70-g09d2