diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 14:57:34 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 14:57:34 +0200 | 
| commit | e80b2593edc3d216905279ebcfa797593a1efbfc (patch) | |
| tree | 5e5057e03f35369983f6600efc59c438c0cf2366 | |
| parent | 2ac16efe59051e0cdeb37422ab579c8d354d562a (diff) | |

Fast Fractional ops via C code

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | bench/Main.hs | 13 |
| -rw-r--r-- | cbits/arith.c | 60 |
| -rw-r--r-- | cbits/arith_lists.h | 4 |
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 16 |
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 56 |
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 18 |
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 31 |
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 4 |

8 files changed, 178 insertions, 24 deletions
| diff --git a/bench/Main.hs b/bench/Main.hs index c4d2879..8f3b670 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -20,6 +20,10 @@ main = defaultMain        let n = 1_000_000        in nf (\(a, b) -> runScalar (rsumOuter1 (arithPromoteRanked2 (mliftPrim2 (*)) a b)))              (riota @Double n, riota n) +    ,bench "sum(/) Double [1e6]" $ +      let n = 1_000_000 +      in nf (\(a, b) -> runScalar (rsumOuter1 (arithPromoteRanked2 (mliftPrim2 (/)) a b))) +            (riota @Double n, riota n)      ,bench "sum Double [1e6]" $        let n = 1_000_000        in nf (\a -> runScalar (rsumOuter1 a)) @@ -34,6 +38,10 @@ main = defaultMain        let n = 1_000_000        in nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))              (riota @Double n, riota n) +    ,bench "sum(/) Double [1e6]" $ +      let n = 1_000_000 +      in nf (\(a, b) -> runScalar (rsumOuter1 (a / b))) +            (riota @Double n, riota n)      ,bench "sum Double [1e6]" $        let n = 1_000_000        in nf (\a -> runScalar (rsumOuter1 a)) @@ -50,6 +58,11 @@ main = defaultMain        in nf (\(a, b) -> LA.sumElements (a * b))              (LA.linspace @Double n (0.0, fromIntegral (n - 1))              ,LA.linspace @Double n (0.0, fromIntegral (n - 1))) +    ,bench "sum(/) Double [1e6]" $ +      let n = 1_000_000 +      in nf (\(a, b) -> LA.sumElements (a / b)) +            (LA.linspace @Double n (0.0, fromIntegral (n - 1)) +            ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))      ,bench "sum Double [1e6]" $        let n = 1_000_000        in nf (\a -> LA.sumElements a) diff --git a/cbits/arith.c b/cbits/arith.c index 65cdb41..a71c1b9 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -8,7 +8,9 @@  // These are the wrapper macros used in arith_lists.h. Preset them to empty to  // avoid having to touch macros unrelated to the particular operation set below.  
#define LIST_BINOP(name, id, hsop) +#define LIST_FBINOP(name, id, hsop)  #define LIST_UNOP(name, id, _) +#define LIST_FUNOP(name, id, _)  #define LIST_REDOP(name, id, _) @@ -147,6 +149,34 @@ enum binop_tag_t {      } \    } +enum fbinop_tag_t { +#undef LIST_FBINOP +#define LIST_FBINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_FBINOP +#define LIST_FBINOP(name, id, hsop) +}; + +#define ENTRY_FBINARY_OPS(typ) \ +  void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \ +    switch (tag) { \ +      case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \ +      default: wrong_op("binary_sv", tag); \ +    } \ +  } \ +  void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ +    switch (tag) { \ +      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \ +      default: wrong_op("binary_vs", tag); \ +    } \ +  } \ +  void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ +    switch (tag) { \ +      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \ +      default: wrong_op("binary_vv", tag); \ +    } \ +  } +  enum unop_tag_t {  #undef LIST_UNOP  #define LIST_UNOP(name, id, _) name = id, @@ -165,6 +195,22 @@ enum unop_tag_t {      } \    } +enum funop_tag_t { +#undef LIST_FUNOP +#define LIST_FUNOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_FUNOP +#define LIST_FUNOP(name, id, _) +}; + +#define ENTRY_FUNARY_OPS(typ) \ +  void oxarop_funary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \ +    switch (tag) { \ +      case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \ +      default: wrong_op("unary", tag); \ +    } \ +  } +  enum redop_tag_t {  #undef LIST_REDOP  #define LIST_REDOP(name, id, _) name = id, @@ -187,8 +233,8 @@ enum redop_tag_t {   *                        Generate all the functions                         *   
*****************************************************************************/ -#define NUM_TYPES_LOOP_XLIST \ -  X(i32) X(i64) X(double) X(float) +#define FLOAT_TYPES_XLIST X(double) X(float) +#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST  #define X(typ) \    COMM_OP(add, +, typ) \ @@ -202,5 +248,13 @@ enum redop_tag_t {    ENTRY_BINARY_OPS(typ) \    ENTRY_UNARY_OPS(typ) \    ENTRY_REDUCE_OPS(typ) -NUM_TYPES_LOOP_XLIST +NUM_TYPES_XLIST +#undef X + +#define X(typ) \ +  NONCOMM_OP(fdiv, /, typ) \ +  UNARY_OP(recip, 1.0/, typ) \ +  ENTRY_FBINARY_OPS(typ) \ +  ENTRY_FUNARY_OPS(typ) +FLOAT_TYPES_XLIST  #undef X diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index c7495e8..1137c18 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -2,9 +2,13 @@ LIST_BINOP(BO_ADD, 1, +)  LIST_BINOP(BO_SUB, 2, -)  LIST_BINOP(BO_MUL, 3, *) +LIST_FBINOP(FB_DIV, 1, /) +  LIST_UNOP(UO_NEG, 1,)  LIST_UNOP(UO_ABS, 2,)  LIST_UNOP(UO_SIGNUM, 3,) +LIST_FUNOP(FU_RECIP, 1,) +  LIST_REDOP(RO_SUM1, 1,)  LIST_REDOP(RO_PRODUCT1, 2,) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 94f08bf..ef2ad6b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1048,12 +1048,12 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where    signum = mliftNumElt1 numEltSignum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" -  recip = mliftPrim recip -  (/) = mliftPrim2 (/) +  recip = mliftNumElt1 floatEltRecip +  (/) = mliftNumElt2 floatEltDiv -instance (NumElt a, PrimElt a, Floating a) => Floating (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => 
Floating (Mixed sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"    exp = mliftPrim exp    log = mliftPrim log @@ -1367,12 +1367,12 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where    signum = arithPromoteRanked signum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"    recip = arithPromoteRanked recip    (/) = arithPromoteRanked2 (/) -instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"    exp = arithPromoteRanked exp    log = arithPromoteRanked log @@ -1698,12 +1698,12 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where    signum = arithPromoteShaped signum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"    recip = arithPromoteShaped recip    (/) = arithPromoteShaped2 (/) -instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"    exp = arithPromoteShaped exp    log = arithPromoteShaped log diff --git 
a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 7484455..07d5d8a 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()  flipOp f n out v s = f n out s v -class NumElt a where -  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a -  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do                 ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]                     return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . 
forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_fbinary_" ++ atCName arithtype +          c_ss = varE (afboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) +  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do                 ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]                     return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) +  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . 
forM [minBound..maxBound] $ \arithop -> do @@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn    | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64    | otherwise = error "Unsupported Int width" +class NumElt a where +  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a +  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  instance NumElt Int32 where    numEltAdd = addVectorInt32    numEltSub = subVectorInt32 @@ -334,3 +358,15 @@ instance NumElt CInt where    numEltProduct1Inner = intWidBranchRed @CInt                            (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))                            (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where +  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where +  floatEltDiv = divVectorFloat +  floatEltRecip = recipVectorFloat + +instance FloatElt Double where +  floatEltDiv = divVectorDouble +  floatEltRecip = recipVectorDouble diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index 49effa1..ac83188 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -24,12 +24,30 @@ $(fmap concat . forM typesList $ \arithtype -> do                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])        ]) +$(fmap concat . 
forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "fbinary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) +  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      let base = "unary_" ++ atCName arithtype      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>        [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "funary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +  $(fmap concat . 
forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      let base = "reduce_" ++ atCName arithtype diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 91e50ad..ce2836d 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -14,13 +14,18 @@ data ArithType = ArithType    , atCName :: String  -- "i32"    } +floatTypesList :: [ArithType] +floatTypesList = +  [ArithType ''Float "float" +  ,ArithType ''Double "double" +  ] +  typesList :: [ArithType]  typesList =    [ArithType ''Int32 "i32"    ,ArithType ''Int64 "i64" -  ,ArithType ''Float "float" -  ,ArithType ''Double "double"    ] +  ++ floatTypesList  -- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)  $(genArithDataType Binop "ArithBOp") @@ -37,6 +42,21 @@ $(do clauses <- readArithLists Binop                ,return $ FunD (mkName "aboNumOp") clauses]) +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] +              ,return $ FunD (mkName "afboNumOp") clauses]) + +  -- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)  $(genArithDataType Unop "ArithUOp") @@ -44,6 +64,13 @@ $(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))  $(genArithEnumFun Unop ''ArithUOp "auoEnum") +-- data ArithFUOp = FU_RECIP | ... 
deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + +  -- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)  $(genArithDataType Redop "ArithRedOp") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs index b748b97..b40a066 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs @@ -9,7 +9,7 @@ import Language.Haskell.TH  import Text.Read -data OpKind = Binop | Unop | Redop +data OpKind = Binop | FBinop | Unop | FUnop | Redop    deriving (Show, Eq)  readArithLists :: OpKind @@ -46,7 +46,9 @@ readArithLists targetkind fop fcombine = do      parseField s = break (`elem` ",)") (dropWhile (== ' ') s)      parseKind "BINOP" = Just Binop +    parseKind "FBINOP" = Just FBinop      parseKind "UNOP" = Just Unop +    parseKind "FUNOP" = Just FUnop      parseKind "REDOP" = Just Redop      parseKind _ = Nothing | 
