aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
commite80b2593edc3d216905279ebcfa797593a1efbfc (patch)
tree5e5057e03f35369983f6600efc59c438c0cf2366 /src/Data/Array/Nested/Internal
parent2ac16efe59051e0cdeb37422ab579c8d354d562a (diff)
Fast Fractional ops via C code
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs56
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs18
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs31
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs4
4 files changed, 96 insertions, 13 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 7484455..07d5d8a 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
-
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss = varE (afboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -334,3 +358,15 @@ instance NumElt CInt where
numEltProduct1Inner = intWidBranchRed @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+class FloatElt a where
+ floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltRecip = recipVectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltRecip = recipVectorDouble
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index 49effa1..ac83188 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -24,12 +24,30 @@ $(fmap concat . forM typesList $ \arithtype -> do
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "fbinary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "unary_" ++ atCName arithtype
pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "funary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "reduce_" ++ atCName arithtype
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 91e50ad..ce2836d 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -14,13 +14,18 @@ data ArithType = ArithType
, atCName :: String -- "i32"
}
+floatTypesList :: [ArithType]
+floatTypesList =
+ [ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
typesList :: [ArithType]
typesList =
[ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
- ,ArithType ''Float "float"
- ,ArithType ''Double "double"
]
+ ++ floatTypesList
-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
$(genArithDataType Binop "ArithBOp")
@@ -37,6 +42,21 @@ $(do clauses <- readArithLists Binop
,return $ FunD (mkName "aboNumOp") clauses])
+-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
$(genArithDataType Unop "ArithUOp")
@@ -44,6 +64,13 @@ $(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
$(genArithDataType Redop "ArithRedOp")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
index b748b97..b40a066 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -9,7 +9,7 @@ import Language.Haskell.TH
import Text.Read
-data OpKind = Binop | Unop | Redop
+data OpKind = Binop | FBinop | Unop | FUnop | Redop
deriving (Show, Eq)
readArithLists :: OpKind
@@ -46,7 +46,9 @@ readArithLists targetkind fop fcombine = do
parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
parseKind "BINOP" = Just Binop
+ parseKind "FBINOP" = Just FBinop
parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
parseKind "REDOP" = Just Redop
parseKind _ = Nothing