diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 77 |
1 files changed, 48 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4bfc043..7484455 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -182,14 +182,13 @@ class NumElt a where $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM binopsList $ \arithop -> do + fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype - c_ss = varE (aboScalFun arithop arithtype) - c_sv = varE $ mkName (cnamebase ++ "_sv") - c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs") - | otherwise = [| flipOp $c_sv |] - c_vv = varE $ mkName (cnamebase ++ "_vv") + cnamebase = "c_binary_" ++ atCName arithtype + c_ss = varE (aboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] @@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM unopsList $ \arithop -> do + fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) + c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] @@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM redopsList $ \redop -> do - let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) - c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] @@ -297,21 +296,41 @@ instance NumElt Double where numEltProduct1Inner = product1VectorDouble instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv - numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv - numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv - numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 - numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 - numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 - numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 - numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 + numEltAdd = intWidBranch2 @Int (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @Int (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @Int (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv - numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv - numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv - numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 - numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 - numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 - numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 - numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 + numEltAdd = intWidBranch2 @CInt (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @CInt (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @CInt (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) |