From a0010622885dcb55a916bf3514c0e9040f6871e9 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 23 May 2024 00:18:17 +0200 Subject: Fast numeric operations for Num --- .gitignore | 1 + bench/Main.hs | 27 ++- cbits/arith.c | 49 +++++ ox-arrays.cabal | 9 + src/Data/Array/Nested/Internal.hs | 46 ++++- src/Data/Array/Nested/Internal/Arith.hs | 240 ++++++++++++++++++++++++ src/Data/Array/Nested/Internal/Arith/Foreign.hs | 33 ++++ src/Data/Array/Nested/Internal/Arith/Lists.hs | 47 +++++ 8 files changed, 442 insertions(+), 10 deletions(-) create mode 100644 cbits/arith.c create mode 100644 src/Data/Array/Nested/Internal/Arith.hs create mode 100644 src/Data/Array/Nested/Internal/Arith/Foreign.hs create mode 100644 src/Data/Array/Nested/Internal/Arith/Lists.hs diff --git a/.gitignore b/.gitignore index a3ac1fc..56ab906 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ dist-newstyle/ cabal.project.local +.ccls-cache/ diff --git a/bench/Main.hs b/bench/Main.hs index d8582fe..c1fc150 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -2,6 +2,7 @@ {-# LANGUAGE TypeApplications #-} module Main where +import qualified Numeric.LinearAlgebra as LA import Test.Tasty.Bench import Data.Array.Nested @@ -16,5 +17,27 @@ main = defaultMain (riota @Double n, riota n) ,bench "sum(*) Double [1e6]" $ let n = 1_000_000 - in nf (\(a, b) -> runScalar (rsumOuter1 (a + b))) - (riota @Double n, riota n)]] + in nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + (riota @Double n, riota n) + ,bench "sum Double [1e6]" $ + let n = 1_000_000 + in nf (\a -> runScalar (rsumOuter1 a)) + (riota @Double n) + ] + ,bgroup "hmatrix" + [bench "sum(+) Double [1e6]" $ + let n = 1_000_000 + in nf (\(a, b) -> LA.sumElements (a + b)) + (LA.linspace @Double n (fromIntegral 0, fromIntegral (n - 1)) + ,LA.linspace @Double n (fromIntegral 0, fromIntegral (n - 1))) + ,bench "sum(*) Double [1e6]" $ + let n = 1_000_000 + in nf (\(a, b) -> LA.sumElements (a * b)) + (LA.linspace @Double n (fromIntegral 0, fromIntegral 
(n - 1)) + ,LA.linspace @Double n (fromIntegral 0, fromIntegral (n - 1))) + ,bench "sum Double [1e6]" $ + let n = 1_000_000 + in nf (\a -> LA.sumElements a) + (LA.linspace @Double n (fromIntegral 0, fromIntegral (n - 1))) + ] + ] diff --git a/cbits/arith.c b/cbits/arith.c new file mode 100644 index 0000000..02c8ce1 --- /dev/null +++ b/cbits/arith.c @@ -0,0 +1,49 @@ +#include <stdint.h> +#include <stdlib.h> +#include <math.h> + +typedef int32_t i32; +typedef int64_t i64; + +#define COMM_OP(name, op, typ) \ + void oxarop_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ + } \ + void oxarop_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ + } + +#define NONCOMM_OP(name, op, typ) \ + COMM_OP(name, op, typ) \ + void oxarop_ ## name ## _ ## typ ## _vs(i64 n, typ *out, typ *x, typ y) { \ + for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ + } + +#define UNARY_OP(name, op, typ) \ + void oxarop_ ## name ## _ ## typ(i64 n, typ *out, typ *x) { \ + for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ + } + +#define GEN_ABS(x) \ + _Generic((x), \ + int: abs, \ + long: labs, \ + long long: llabs, \ + float: fabsf, \ + double: fabs)(x) + +// This does not result in multiple loads with GCC 13. +#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ?
1 : 0) + +#define NUM_TYPES_LOOP_XLIST \ + X(i32) X(i64) X(double) X(float) + +#define X(typ) \ + COMM_OP(add, +, typ) \ + NONCOMM_OP(sub, -, typ) \ + COMM_OP(mul, *, typ) \ + UNARY_OP(neg, -, typ) \ + UNARY_OP(abs, GEN_ABS, typ) \ + UNARY_OP(signum, GEN_SIGNUM, typ) +NUM_TYPES_LOOP_XLIST +#undef X diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 58fccf9..3f4fa5b 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -10,15 +10,23 @@ library Data.Array.Mixed Data.Array.Nested Data.Array.Nested.Internal + Data.Array.Nested.Internal.Arith + Data.Array.Nested.Internal.Arith.Foreign + Data.Array.Nested.Internal.Arith.Lists build-depends: base >=4.18 && <4.20, ghc-typelits-knownnat, ghc-typelits-natnormalise, orthotope, + template-haskell, vector hs-source-dirs: src + c-sources: cbits/arith.c + -- hmatrix assumes sse2, so we can too + cc-options: -O3 -msse2 -Wall -Wextra default-language: Haskell2010 ghc-options: -Wall + other-extensions: TemplateHaskell test-suite example type: exitcode-stdio-1.0 @@ -36,6 +44,7 @@ benchmark bench build-depends: ox-arrays, base, + hmatrix, tasty-bench hs-source-dirs: bench default-language: Haskell2010 diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 588237d..831a9b5 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DeriveFoldable #-} @@ -62,6 +63,7 @@ import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X +import Data.Array.Nested.Internal.Arith -- Invariant in the API @@ -999,6 +1001,7 @@ mliftPrim2 :: PrimElt a mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) +{-} instance (Num a, PrimElt a) => Num (Mixed sh a) where (+) = mliftPrim2 (+) (-) = mliftPrim2 (-) @@ -1008,12 +1011,39 
@@ instance (Num a, PrimElt a) => Num (Mixed sh a) where signum = mliftPrim signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" -instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where +type NumConstr a = Num a +--} + +{--} +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr)) + +mliftNumElt2 :: PrimElt a + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +-- TODO: Clean up this mess and remove NumConstr +type NumConstr a = NumElt a + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 numEltAdd + (-) = mliftNumElt2 numEltSub + (*) = mliftNumElt2 numEltMul + negate = mliftNumElt1 numEltNeg + abs = mliftNumElt1 numEltAbs + signum = mliftNumElt1 numEltSignum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" +--} + +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" recip = mliftPrim recip (/) = mliftPrim2 (/) -instance (Floating a, PrimElt a) => Floating (Mixed sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" exp = mliftPrim exp log = mliftPrim log @@ -1316,7 +1346,7 @@ 
arithPromoteRanked2 :: forall n a. PrimElt a -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 = coerce -instance (Num a, PrimElt a) => Num (Ranked n a) where +instance (NumConstr a, PrimElt a) => Num (Ranked n a) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) @@ -1325,12 +1355,12 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where signum = arithPromoteRanked signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" -instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Ranked n a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) -instance (Floating a, PrimElt a) => Floating (Ranked n a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Ranked n a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" exp = arithPromoteRanked exp log = arithPromoteRanked log @@ -1616,7 +1646,7 @@ arithPromoteShaped2 :: forall sh a. 
PrimElt a -> Shaped sh a -> Shaped sh a -> Shaped sh a arithPromoteShaped2 = coerce -instance (Num a, PrimElt a) => Num (Shaped sh a) where +instance (NumConstr a, PrimElt a) => Num (Shaped sh a) where (+) = arithPromoteShaped2 (+) (-) = arithPromoteShaped2 (-) (*) = arithPromoteShaped2 (*) @@ -1625,12 +1655,12 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where signum = arithPromoteShaped signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" -instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Shaped sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) -instance (Floating a, PrimElt a) => Floating (Shaped sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Shaped sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate" exp = arithPromoteShaped exp log = arithPromoteShaped log diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs new file mode 100644 index 0000000..4312cd5 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -0,0 +1,240 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits 
+import Language.Haskell.TH
+import System.IO.Unsafe
+
+import Data.Array.Nested.Internal.Arith.Foreign
+import Data.Array.Nested.Internal.Arith.Lists
+
+
+-- | Lift an elementwise function on storable vectors to an orthotope array.
+--
+-- If 'stridesDense' reports that the array occupies a dense prefix of its
+-- backing vector (after the offset), @f@ is applied to exactly that prefix
+-- and the original strides are reused with offset 0, so no re-linearisation
+-- is needed. Otherwise the array is linearised with 'RS.toVector' first.
+--
+-- @f@ must be elementwise: the output must have the same length as the input
+-- and element @i@ of the output must depend only on element @i@ of the input.
+mliftVEltwise1 :: Storable a
+               => SNat n
+               -> (VS.Vector a -> VS.Vector a)
+               -> RS.Array n a -> RS.Array n a
+mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+  | Just prefixSz <- stridesDense sh strides =
+      let vec' = f (VS.slice offset prefixSz vec)
+      in RS.A (RG.A sh (OI.T strides 0 vec'))
+  | otherwise = RS.fromVector sh (f (RS.toVector arr))
+
+-- | Lift an elementwise binary function to two orthotope arrays of equal
+-- shape. Each operand is handed to @f@ either as a single scalar ('Left') or
+-- as a vector ('Right').
+--
+-- When @f@ receives a 'Right' vector, its result must have the same length
+-- as that vector, with element @i@ of the result corresponding to element
+-- @i@ of the input.
+--
+-- Raises an 'error' if the two shapes differ.
+mliftVEltwise2 :: Storable a
+               => SNat n
+               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+               -> RS.Array n a -> RS.Array n a -> RS.Array n a
+mliftVEltwise2 SNat f
+    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+  | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs
+  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
+      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
+        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+      (Just 1, Just n) ->  -- scalar * dense
+        -- The dense prefix is stored in the order dictated by strides2, which
+        -- need not be row-major (e.g. a transposed array), so the result must
+        -- keep those strides rather than be reinterpreted with 'RS.fromVector'.
+        let vec' = f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))
+        in RS.A (RG.A sh1 (OI.T strides2 0 vec'))
+      (Just n, Just 1) ->  -- dense * scalar
+        -- Same reasoning as above, with the roles of the operands swapped.
+        let vec' = f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))
+        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+      (_, _) ->  -- fallback case
+        -- Both operands are linearised to row-major order, so 'RS.fromVector'
+        -- is the right way to rebuild the result here.
+        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+
+-- | Given the shape vector and the stride vector, return whether this vector
+-- of strides uses a dense prefix of its backing array. If so, the number of
+-- elements in this prefix is returned.
+-- This excludes any offset.
+--
+-- Dimensions of size 1 are ignored, because their stride is never multiplied
+-- by a non-zero index. If all remaining strides are zero (this includes
+-- rank-0 arrays and fully replicated scalars), the array uses a prefix of
+-- exactly one element. 'Nothing' is returned for every layout that is not a
+-- gap-free prefix (e.g. reversed or strided slices), so that callers can
+-- take their generic fallback path.
+--
+-- Raises an 'error' only if the two lists have different lengths.
+stridesDense :: [Int] -> [Int] -> Maybe Int
+stridesDense sh str
+  | any (<= 0) sh = Just 0
+  | length sh /= length str =
+      error "Orthotope array's shape vector and stride vector have different lengths"
+  | all ((== 0) . fst) relevant = Just 1
+  | otherwise =
+      -- sort dimensions on their stride, ascending
+      case sort relevant of
+        (1, n) : rest -> checkCover n rest
+        _ -> Nothing  -- smallest stride is not 1: not a dense prefix
+  where
+    -- (stride, size) pairs of the dimensions that can actually move the index
+    relevant :: [(Int, Int)]
+    relevant = [(s, n) | (s, n) <- zip str sh, n /= 1]
+
+    -- Given size of currently densely covered region at beginning of the
+    -- array and the remaining (stride, size) pairs in ascending stride order,
+    -- return whether this all together covers a dense prefix of the array.
+    -- If it does, return the number of elements in this prefix.
+    checkCover :: Int -> [(Int, Int)] -> Maybe Int
+    checkCover block [] = Just block
+    checkCover block ((s, n) : rest) =
+      -- A stride larger than the covered block would leave a gap. With
+      -- s <= block the n shifted copies of the block overlap or touch, and
+      -- together span block + (n - 1) * s elements.
+      guard (s <= block) >> checkCover (block + (n - 1) * s) rest
+
+-- | Apply a unary foreign kernel to a storable vector, producing a fresh
+-- vector of the same length.
+--
+-- The kernel receives the element count, the output pointer and the input
+-- pointer, in that order; @ptrconv@ converts the Haskell element pointers to
+-- the pointer type the kernel expects.
+--
+-- NOTE(review): the 'unsafePerformIO' is only sound if the kernel writes
+-- nothing but the output buffer and fills it completely -- confirm for every
+-- kernel passed in.
+vectorOp1 :: forall a b. Storable a
+          => (Ptr a -> Ptr b)
+          -> (Int64 -> Ptr b -> Ptr b -> IO ())
+          -> VS.Vector a -> VS.Vector a
+vectorOp1 ptrconv f v = unsafePerformIO $ do
+  outv <- VSM.unsafeNew (VS.length v)
+  VSM.unsafeWith outv $ \poutv ->
+    VS.unsafeWith v $ \pv ->
+      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
+  VS.unsafeFreeze outv
+
+-- | If two vectors are given, assumes that they have the same length.
+vectorOp2 :: forall a b.
Storable a
+          -- Arguments, in order: scalar conversion to the kernel's element
+          -- type, pointer conversion to the kernel's pointer type, the pure
+          -- scalar-scalar operation, and the three foreign kernels.
+          => (a -> b)
+          -> (Ptr a -> Ptr b)
+          -> (a -> a -> a)
+          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv
+          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs
+          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv
+          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
+vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
+  -- scalar `op` scalar: computed in Haskell, result is a one-element vector
+  (Left x) (Left y) -> VS.singleton (fss x y)
+
+  -- scalar `op` vector: output has the length of the vector operand
+  (Left x) (Right vy) ->
+    unsafePerformIO $ do
+      outv <- VSM.unsafeNew (VS.length vy)
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith vy $ \pvy ->
+          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
+      VS.unsafeFreeze outv
+
+  -- vector `op` scalar: output has the length of the vector operand
+  (Right vx) (Left y) ->
+    unsafePerformIO $ do
+      outv <- VSM.unsafeNew (VS.length vx)
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith vx $ \pvx ->
+          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
+      VS.unsafeFreeze outv
+
+  -- vector `op` vector: lengths are checked before the kernel is called
+  (Right vx) (Right vy)
+    | VS.length vx == VS.length vy ->
+        unsafePerformIO $ do
+          outv <- VSM.unsafeNew (VS.length vx)
+          VSM.unsafeWith outv $ \poutv ->
+            VS.unsafeWith vx $ \pvx ->
+              VS.unsafeWith vy $ \pvy ->
+                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
+          VS.unsafeFreeze outv
+    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+
+-- | Turn a scalar-vector kernel into a vector-scalar kernel by swapping the
+-- two operand arguments. The operands are passed to the underlying kernel in
+-- swapped order, so this is only a valid substitute for a dedicated
+-- vector-scalar kernel when the operation is commutative.
+flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
+       -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
+flipOp f n out v s = f n out s v
+
+-- | Element types with elementwise arithmetic on orthotope arrays: three
+-- binary operations (add, sub, mul) and three unary ones (neg, abs, signum).
+class NumElt a where
+  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+
+-- For every type in 'typesList' and every operation in 'binopsList', generate
+-- a top-level function named <op>Vector<Type> (operation name, the literal
+-- "Vector", then the base name of the type) together with its type
+-- signature. Commutative operations have no dedicated _vs import, so their
+-- vector-scalar kernel is the _sv kernel wrapped in 'flipOp'.
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM binopsList $ \arithop -> do
+      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype
+          c_ss = varE (aboScalFun arithop arithtype)
+          c_sv = varE $ mkName (cnamebase ++ "_sv")
+          c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs")
+               | otherwise = [| flipOp $c_sv |]
+          c_vv = varE $ mkName (cnamebase ++ "_vv")
+      sequence [SigD name <$>
+                  [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+               ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Same scheme for the unary operations in 'unopsList': one function per
+-- (type, operation) pair, wrapping the c_<op>_<ctype> foreign import.
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM unopsList $ \arithop -> do
+      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
+      sequence [SigD name <$>
+                  [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+               ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+-- | Select the 32-bit or 64-bit unary kernel for an integral type @i@ based
+-- on its 'finiteBitSize', reinterpreting the element pointers with 'castPtr'.
+-- Any other width raises an 'error'.
+--
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
+              -> (SNat n -> RS.Array n i -> RS.Array n i)
+intWidBranch1 f32 f64 sn
+  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+  | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n.
(FiniteBits i, Storable i, Integral i)
+              => (i -> i -> i)  -- ss
+                 -- int32
+              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv
+              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs
+              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv
+                 -- int64
+              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv
+              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
+-- Binary counterpart of intWidBranch1: picks the 32-bit or 64-bit kernels
+-- based on the bit size of i, converting scalars with fromIntegral and
+-- pointers with castPtr. Any other width raises an error.
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+  | otherwise = error "Unsupported Int width"
+
+instance NumElt Int32 where
+  numEltAdd = addVectorInt32
+  numEltSub = subVectorInt32
+  numEltMul = mulVectorInt32
+  numEltNeg = negVectorInt32
+  numEltAbs = absVectorInt32
+  numEltSignum = signumVectorInt32
+
+instance NumElt Int64 where
+  numEltAdd = addVectorInt64
+  numEltSub = subVectorInt64
+  numEltMul = mulVectorInt64
+  numEltNeg = negVectorInt64
+  numEltAbs = absVectorInt64
+  numEltSignum = signumVectorInt64
+
+instance NumElt Float where
+  numEltAdd = addVectorFloat
+  numEltSub = subVectorFloat
+  numEltMul = mulVectorFloat
+  numEltNeg = negVectorFloat
+  numEltAbs = absVectorFloat
+  numEltSignum = signumVectorFloat
+
+instance NumElt Double where
+  numEltAdd = addVectorDouble
+  numEltSub = subVectorDouble
+  numEltMul = mulVectorDouble
+  numEltNeg = negVectorDouble
+  numEltAbs = absVectorDouble
+  numEltSignum = signumVectorDouble
+
+-- Int and CInt dispatch on their bit size to the i32 or i64 kernels.
+-- Addition and multiplication are commutative, so their vector-scalar kernel
+-- is the scalar-vector kernel with the operands flipped. Subtraction is NOT
+-- commutative and must use the dedicated _vs kernel: flipping the _sv kernel
+-- would compute (scalar - element) instead of (element - scalar).
+instance NumElt Int where
+  numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
+  numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv c_sub_i32_vs c_sub_i32_vv c_sub_i64_sv c_sub_i64_vs c_sub_i64_vv
+  numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
+  numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
+  numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
+  numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
+
+instance NumElt CInt where
+  numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
+  numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv c_sub_i32_vs c_sub_i32_vv c_sub_i64_sv c_sub_i64_vs c_sub_i64_vv
+  numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
+  numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
+  numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
+  numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
new file mode 100644
index 0000000..dbd9ddc
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -0,0 +1,33 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Internal.Arith.Foreign where
+
+import Control.Monad
+import Data.Int
+import Data.Maybe
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Nested.Internal.Arith.Lists
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM binopsList $ \arithop -> do
+      let base = aboName arithop ++ "_" ++ atCName arithtype
+      sequence $ catMaybes
+        [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+                 [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+        ,Just (ForeignD .
ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+                 [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+        -- the vector-scalar variant is only imported for non-commutative operations
+        ,guard (aboComm arithop == NonComm) >>
+           Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+        ])
+
+-- For every type in 'typesList' and every operation in 'unopsList', import
+-- the C symbol oxarop_<op>_<ctype> under the Haskell name c_<op>_<ctype>.
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    forM unopsList $ \arithop -> do
+      let base = auoName arithop ++ "_" ++ atCName arithtype
+      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+        [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
new file mode 100644
index 0000000..1b29770
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists where
+
+import Data.Int
+
+import Language.Haskell.TH
+
+
+-- | Whether a binary operation is commutative.
+data Commutative = Comm | NonComm
+  deriving (Show, Eq)
+
+-- | An element type: its Haskell type name paired with the string used for
+-- it when building foreign symbol names.
+data ArithType = ArithType
+  { atType :: Name  -- ''Int32
+  , atCName :: String  -- "i32"
+  }
+
+-- | The element types listed for code generation.
+typesList :: [ArithType]
+typesList =
+  [ArithType ''Int32 "i32"
+  ,ArithType ''Int64 "i64"
+  ,ArithType ''Float "float"
+  ,ArithType ''Double "double"
+  ]
+
+-- | A binary operation: the name fragment used in generated identifiers, its
+-- commutativity, and the name of the Haskell function implementing it on
+-- scalars of the given type.
+data ArithBOp = ArithBOp
+  { aboName :: String  -- "add"
+  , aboComm :: Commutative  -- Comm
+  , aboScalFun :: ArithType -> Name  -- \_ -> '(+)
+  }
+
+-- | The binary operations listed for code generation.
+binopsList :: [ArithBOp]
+binopsList =
+  [ArithBOp "add" Comm (\_ -> '(+))
+  ,ArithBOp "sub" NonComm (\_ -> '(-))
+  ,ArithBOp "mul" Comm (\_ -> '(*))
+  ]
+
+-- | A unary operation, identified by the name fragment used in generated
+-- identifiers.
+data ArithUOp = ArithUOp
+  { auoName :: String  -- "neg"
+  }
+
+-- | The unary operations listed for code generation.
+unopsList :: [ArithUOp]
+unopsList =
+  [ArithUOp "neg"
+  ,ArithUOp "abs"
+  ,ArithUOp "signum"
+  ]
-- 
cgit v1.2.3-70-g09d2