aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 00:18:17 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 00:18:17 +0200
commita0010622885dcb55a916bf3514c0e9040f6871e9 (patch)
tree9e10c18eaf5c873d50e1f88a3bf114179c151769 /src/Data/Array/Nested/Internal.hs
parent4b74d1b1f7c46a4b3907838bee11f669060d3a23 (diff)
Fast numeric operations for Num
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs46
1 files changed, 38 insertions, 8 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 588237d..831a9b5 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFoldable #-}
@@ -62,6 +63,7 @@ import Unsafe.Coerce
import Data.Array.Mixed
import qualified Data.Array.Mixed as X
+import Data.Array.Nested.Internal.Arith
-- Invariant in the API
@@ -999,6 +1001,7 @@ mliftPrim2 :: PrimElt a
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
+{-}
instance (Num a, PrimElt a) => Num (Mixed sh a) where
(+) = mliftPrim2 (+)
(-) = mliftPrim2 (-)
@@ -1008,12 +1011,39 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where
signum = mliftPrim signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where
+type NumConstr a = Num a
+--}
+
+{--}
+mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr))
+
+mliftNumElt2 :: PrimElt a
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+-- TODO: Clean up this mess and remove NumConstr
+type NumConstr a = NumElt a
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 numEltAdd
+ (-) = mliftNumElt2 numEltSub
+ (*) = mliftNumElt2 numEltMul
+ negate = mliftNumElt1 numEltNeg
+ abs = mliftNumElt1 numEltAbs
+ signum = mliftNumElt1 numEltSignum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+--}
+
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
recip = mliftPrim recip
(/) = mliftPrim2 (/)
-instance (Floating a, PrimElt a) => Floating (Mixed sh a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
exp = mliftPrim exp
log = mliftPrim log
@@ -1316,7 +1346,7 @@ arithPromoteRanked2 :: forall n a. PrimElt a
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 = coerce
-instance (Num a, PrimElt a) => Num (Ranked n a) where
+instance (NumConstr a, PrimElt a) => Num (Ranked n a) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
@@ -1325,12 +1355,12 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where
signum = arithPromoteRanked signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Ranked n a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
-instance (Floating a, PrimElt a) => Floating (Ranked n a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Ranked n a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
@@ -1616,7 +1646,7 @@ arithPromoteShaped2 :: forall sh a. PrimElt a
-> Shaped sh a -> Shaped sh a -> Shaped sh a
arithPromoteShaped2 = coerce
-instance (Num a, PrimElt a) => Num (Shaped sh a) where
+instance (NumConstr a, PrimElt a) => Num (Shaped sh a) where
(+) = arithPromoteShaped2 (+)
(-) = arithPromoteShaped2 (-)
(*) = arithPromoteShaped2 (*)
@@ -1625,12 +1655,12 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where
signum = arithPromoteShaped signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Shaped sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
-instance (Floating a, PrimElt a) => Floating (Shaped sh a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Shaped sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"
exp = arithPromoteShaped exp
log = arithPromoteShaped log