aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
commite80b2593edc3d216905279ebcfa797593a1efbfc (patch)
tree5e5057e03f35369983f6600efc59c438c0cf2366 /src/Data/Array/Nested/Internal/Arith.hs
parent2ac16efe59051e0cdeb37422ab579c8d354d562a (diff)
Fast Fractional ops via C code
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs56
1 files changed, 46 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 7484455..07d5d8a 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
-
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss = varE (afboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -334,3 +358,15 @@ instance NumElt CInt where
numEltProduct1Inner = intWidBranchRed @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+class FloatElt a where
+ floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltRecip = recipVectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltRecip = recipVectorDouble