aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-21 11:42:17 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-21 11:42:17 +0200
commit3d2e4a567668ea951e629834e6871a3f144c1b84 (patch)
tree3c55c1833fd21bfac84b14a360617459ee5d143f /src/Data/Array/Nested/Internal.hs
parentd4086966b95c2ed556f5628a4cfcc41f5e19fab7 (diff)
Add Eq, Fractional, Floating instances
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs110
1 files changed, 100 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 66b3130..b645f4a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -53,6 +53,7 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.C.Types (CInt(..))
import Foreign.Storable (Storable)
+import qualified GHC.Float (log1p, expm1, log1pexp, log1mexp)
import GHC.IsList (IsList)
import qualified GHC.IsList as IsList
import GHC.TypeLits
@@ -460,16 +461,16 @@ data family Mixed sh a
-- ostensibly not exist; the full array is still empty.
data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Show)
+ deriving (Show, Eq)
-- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show) -- no content, orthotope optimises this (via Vector)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq) -- no content, orthotope optimises this (via Vector)
-- etc.
data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b)
@@ -995,6 +996,35 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where
signum = mliftPrim signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftPrim recip
+ (/) = mliftPrim2 (/)
+
+instance (Floating a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftPrim exp
+ log = mliftPrim log
+ sqrt = mliftPrim sqrt
+ (**) = mliftPrim2 (**)
+ logBase = mliftPrim2 logBase
+ sin = mliftPrim sin
+ cos = mliftPrim cos
+ tan = mliftPrim tan
+ asin = mliftPrim asin
+ acos = mliftPrim acos
+ atan = mliftPrim atan
+ sinh = mliftPrim sinh
+ cosh = mliftPrim cosh
+ tanh = mliftPrim tanh
+ asinh = mliftPrim asinh
+ acosh = mliftPrim acosh
+ atanh = mliftPrim atanh
+ log1p = mliftPrim GHC.Float.log1p
+ expm1 = mliftPrim GHC.Float.expm1
+ log1pexp = mliftPrim GHC.Float.log1pexp
+ log1mexp = mliftPrim GHC.Float.log1mexp
+
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a
mtoRanked arr
| Refl <- X.lemAppNil @sh
@@ -1029,6 +1059,7 @@ mcastToShaped arr targetsh
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's. Note that
@@ -1042,6 +1073,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-- just unwrap the newtype and defer to the general instance for nested arrays
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
@@ -1277,7 +1309,36 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
+
+instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
+ recip = arithPromoteRanked recip
+ (/) = arithPromoteRanked2 (/)
+
+instance (Floating a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
+ exp = arithPromoteRanked exp
+ log = arithPromoteRanked log
+ sqrt = arithPromoteRanked sqrt
+ (**) = arithPromoteRanked2 (**)
+ logBase = arithPromoteRanked2 logBase
+ sin = arithPromoteRanked sin
+ cos = arithPromoteRanked cos
+ tan = arithPromoteRanked tan
+ asin = arithPromoteRanked asin
+ acos = arithPromoteRanked acos
+ atan = arithPromoteRanked atan
+ sinh = arithPromoteRanked sinh
+ cosh = arithPromoteRanked cosh
+ tanh = arithPromoteRanked tanh
+ asinh = arithPromoteRanked asinh
+ acosh = arithPromoteRanked acosh
+ atanh = arithPromoteRanked atanh
+ log1p = arithPromoteRanked GHC.Float.log1p
+ expm1 = arithPromoteRanked GHC.Float.expm1
+ log1pexp = arithPromoteRanked GHC.Float.log1pexp
+ log1mexp = arithPromoteRanked GHC.Float.log1mexp
zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
@@ -1548,7 +1609,36 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
+
+instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
+ recip = arithPromoteShaped recip
+ (/) = arithPromoteShaped2 (/)
+
+instance (Floating a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
+ exp = arithPromoteShaped exp
+ log = arithPromoteShaped log
+ sqrt = arithPromoteShaped sqrt
+ (**) = arithPromoteShaped2 (**)
+ logBase = arithPromoteShaped2 logBase
+ sin = arithPromoteShaped sin
+ cos = arithPromoteShaped cos
+ tan = arithPromoteShaped tan
+ asin = arithPromoteShaped asin
+ acos = arithPromoteShaped acos
+ atan = arithPromoteShaped atan
+ sinh = arithPromoteShaped sinh
+ cosh = arithPromoteShaped cosh
+ tanh = arithPromoteShaped tanh
+ asinh = arithPromoteShaped asinh
+ acosh = arithPromoteShaped acosh
+ atanh = arithPromoteShaped atanh
+ log1p = arithPromoteShaped GHC.Float.log1p
+ expm1 = arithPromoteShaped GHC.Float.expm1
+ log1pexp = arithPromoteShaped GHC.Float.log1pexp
+ log1mexp = arithPromoteShaped GHC.Float.log1mexp
zeroIxS :: ShS sh -> IIxS sh
zeroIxS ZSS = ZIS