aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
commite80b2593edc3d216905279ebcfa797593a1efbfc (patch)
tree5e5057e03f35369983f6600efc59c438c0cf2366
parent2ac16efe59051e0cdeb37422ab579c8d354d562a (diff)
Fast Fractional ops via C code
-rw-r--r--bench/Main.hs13
-rw-r--r--cbits/arith.c60
-rw-r--r--cbits/arith_lists.h4
-rw-r--r--src/Data/Array/Nested/Internal.hs16
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs56
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs18
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs31
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs4
8 files changed, 178 insertions, 24 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index c4d2879..8f3b670 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -20,6 +20,10 @@ main = defaultMain
let n = 1_000_000
in nf (\(a, b) -> runScalar (rsumOuter1 (arithPromoteRanked2 (mliftPrim2 (*)) a b)))
(riota @Double n, riota n)
+ ,bench "sum(/) Double [1e6]" $
+ let n = 1_000_000
+ in nf (\(a, b) -> runScalar (rsumOuter1 (arithPromoteRanked2 (mliftPrim2 (/)) a b)))
+ (riota @Double n, riota n)
,bench "sum Double [1e6]" $
let n = 1_000_000
in nf (\a -> runScalar (rsumOuter1 a))
@@ -34,6 +38,10 @@ main = defaultMain
let n = 1_000_000
in nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
(riota @Double n, riota n)
+ ,bench "sum(/) Double [1e6]" $
+ let n = 1_000_000
+ in nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
+ (riota @Double n, riota n)
,bench "sum Double [1e6]" $
let n = 1_000_000
in nf (\a -> runScalar (rsumOuter1 a))
@@ -50,6 +58,11 @@ main = defaultMain
in nf (\(a, b) -> LA.sumElements (a * b))
(LA.linspace @Double n (0.0, fromIntegral (n - 1))
,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+ ,bench "sum(/) Double [1e6]" $
+ let n = 1_000_000
+ in nf (\(a, b) -> LA.sumElements (a / b))
+ (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+ ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
,bench "sum Double [1e6]" $
let n = 1_000_000
in nf (\a -> LA.sumElements a)
diff --git a/cbits/arith.c b/cbits/arith.c
index 65cdb41..a71c1b9 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -8,7 +8,9 @@
// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
#define LIST_BINOP(name, id, hsop)
+#define LIST_FBINOP(name, id, hsop)
#define LIST_UNOP(name, id, _)
+#define LIST_FUNOP(name, id, _)
#define LIST_REDOP(name, id, _)
@@ -147,6 +149,34 @@ enum binop_tag_t {
} \
}
+enum fbinop_tag_t {
+#undef LIST_FBINOP
+#define LIST_FBINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_FBINOP
+#define LIST_FBINOP(name, id, hsop)
+};
+
+#define ENTRY_FBINARY_OPS(typ) \
+ void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \
+ default: wrong_op("binary_sv", tag); \
+ } \
+ } \
+ void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \
+ default: wrong_op("binary_vs", tag); \
+ } \
+ } \
+ void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \
+ default: wrong_op("binary_vv", tag); \
+ } \
+ }
+
enum unop_tag_t {
#undef LIST_UNOP
#define LIST_UNOP(name, id, _) name = id,
@@ -165,6 +195,22 @@ enum unop_tag_t {
} \
}
+enum funop_tag_t {
+#undef LIST_FUNOP
+#define LIST_FUNOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_FUNOP
+#define LIST_FUNOP(name, id, _)
+};
+
+#define ENTRY_FUNARY_OPS(typ) \
+ void oxarop_funary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \
+ switch (tag) { \
+ case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \
+ default: wrong_op("unary", tag); \
+ } \
+ }
+
enum redop_tag_t {
#undef LIST_REDOP
#define LIST_REDOP(name, id, _) name = id,
@@ -187,8 +233,8 @@ enum redop_tag_t {
* Generate all the functions *
*****************************************************************************/
-#define NUM_TYPES_LOOP_XLIST \
- X(i32) X(i64) X(double) X(float)
+#define FLOAT_TYPES_XLIST X(double) X(float)
+#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST
#define X(typ) \
COMM_OP(add, +, typ) \
@@ -202,5 +248,13 @@ enum redop_tag_t {
ENTRY_BINARY_OPS(typ) \
ENTRY_UNARY_OPS(typ) \
ENTRY_REDUCE_OPS(typ)
-NUM_TYPES_LOOP_XLIST
+NUM_TYPES_XLIST
+#undef X
+
+#define X(typ) \
+ NONCOMM_OP(fdiv, /, typ) \
+ UNARY_OP(recip, 1.0/, typ) \
+ ENTRY_FBINARY_OPS(typ) \
+ ENTRY_FUNARY_OPS(typ)
+FLOAT_TYPES_XLIST
#undef X
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index c7495e8..1137c18 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -2,9 +2,13 @@ LIST_BINOP(BO_ADD, 1, +)
LIST_BINOP(BO_SUB, 2, -)
LIST_BINOP(BO_MUL, 3, *)
+LIST_FBINOP(FB_DIV, 1, /)
+
LIST_UNOP(UO_NEG, 1,)
LIST_UNOP(UO_ABS, 2,)
LIST_UNOP(UO_SIGNUM, 3,)
+LIST_FUNOP(FU_RECIP, 1,)
+
LIST_REDOP(RO_SUM1, 1,)
LIST_REDOP(RO_PRODUCT1, 2,)
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 94f08bf..ef2ad6b 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1048,12 +1048,12 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
signum = mliftNumElt1 numEltSignum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
-instance (NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftPrim recip
- (/) = mliftPrim2 (/)
+ recip = mliftNumElt1 floatEltRecip
+ (/) = mliftNumElt2 floatEltDiv
-instance (NumElt a, PrimElt a, Floating a) => Floating (Mixed sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
exp = mliftPrim exp
log = mliftPrim log
@@ -1367,12 +1367,12 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
signum = arithPromoteRanked signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
-instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where
+instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
-instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where
+instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
@@ -1698,12 +1698,12 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
signum = arithPromoteShaped signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
-instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
-instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
exp = arithPromoteShaped exp
log = arithPromoteShaped log
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 7484455..07d5d8a 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
-
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss = varE (afboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -334,3 +358,15 @@ instance NumElt CInt where
numEltProduct1Inner = intWidBranchRed @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+class FloatElt a where
+ floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltRecip = recipVectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltRecip = recipVectorDouble
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index 49effa1..ac83188 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -24,12 +24,30 @@ $(fmap concat . forM typesList $ \arithtype -> do
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "fbinary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "unary_" ++ atCName arithtype
pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "funary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "reduce_" ++ atCName arithtype
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 91e50ad..ce2836d 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -14,13 +14,18 @@ data ArithType = ArithType
, atCName :: String -- "i32"
}
+floatTypesList :: [ArithType]
+floatTypesList =
+ [ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
typesList :: [ArithType]
typesList =
[ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
- ,ArithType ''Float "float"
- ,ArithType ''Double "double"
]
+ ++ floatTypesList
-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
$(genArithDataType Binop "ArithBOp")
@@ -37,6 +42,21 @@ $(do clauses <- readArithLists Binop
,return $ FunD (mkName "aboNumOp") clauses])
+-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
$(genArithDataType Unop "ArithUOp")
@@ -44,6 +64,13 @@ $(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
$(genArithDataType Redop "ArithRedOp")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
index b748b97..b40a066 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -9,7 +9,7 @@ import Language.Haskell.TH
import Text.Read
-data OpKind = Binop | Unop | Redop
+data OpKind = Binop | FBinop | Unop | FUnop | Redop
deriving (Show, Eq)
readArithLists :: OpKind
@@ -46,7 +46,9 @@ readArithLists targetkind fop fcombine = do
parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
parseKind "BINOP" = Just Binop
+ parseKind "FBINOP" = Just FBinop
parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
parseKind "REDOP" = Just Redop
parseKind _ = Nothing