aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal/Arith.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 313c885..11cbba6 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM intTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_ibinary_" ++ atCName arithtype
+ c_ss_str = varE (aiboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM floatTypesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -794,6 +808,34 @@ instance NumElt CInt where
numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
(scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+class NumElt a => IntElt a where
+ intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+
+instance IntElt Int32 where
+ intEltQuot = quotVectorInt32
+ intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+ intEltQuot = quotVectorInt64
+ intEltRem = remVectorInt64
+
+instance IntElt Int where
+ intEltQuot = intWidBranch2 @Int quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @Int rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+ intEltQuot = intWidBranch2 @CInt quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @CInt rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a