aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
commit34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch)
treef2b2e34d830d66d23ae19909c71771e810c262d0 /src/Data/Array/Nested/Internal/Arith.hs
parent85593969debadbf11ad3c159de71e7b480ca367c (diff)
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of ops increases
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs77
1 files changed, 48 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4bfc043..7484455 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -182,14 +182,13 @@ class NumElt a where
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype
- c_ss = varE (aboScalFun arithop arithtype)
- c_sv = varE $ mkName (cnamebase ++ "_sv")
- c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs")
- | otherwise = [| flipOp $c_sv |]
- c_vv = varE $ mkName (cnamebase ++ "_vv")
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss = varE (aboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
@@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM unopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
+ c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
@@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM redopsList $ \redop -> do
- let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
- c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
@@ -297,21 +296,41 @@ instance NumElt Double where
numEltProduct1Inner = product1VectorDouble
instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))