diff options
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 313c885..11cbba6 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM intTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_ibinary_" ++ atCName arithtype + c_ss_str = varE (aiboNumOp arithop) + c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM floatTypesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -794,6 +808,34 @@ instance NumElt CInt where numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 +class NumElt a => IntElt a where + intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + +instance IntElt Int32 where + intEltQuot = quotVectorInt32 + intEltRem = remVectorInt32 + +instance IntElt Int64 where + intEltQuot = quotVectorInt64 + intEltRem = remVectorInt64 + +instance IntElt Int where + intEltQuot = intWidBranch2 @Int quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @Int rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + +instance IntElt CInt where + intEltQuot = intWidBranch2 @CInt quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @CInt rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + class NumElt a => FloatElt a where floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a |