From ed6acbe5f409aba2fb222693da567ce04b7c4e01 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Wed, 12 Mar 2025 23:20:13 +0100
Subject: Implement quot/rem

---
 src/Data/Array/Mixed/Internal/Arith.hs | 42 ++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

(limited to 'src/Data/Array/Mixed/Internal/Arith.hs')

diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 313c885..11cbba6 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
                ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
                    return $ FunD name [Clause [] (NormalB body) []]])
 
+$(fmap concat . forM intTypesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          cnamebase = "c_ibinary_" ++ atCName arithtype
+          c_ss_str = varE (aiboNumOp arithop)
+          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
 $(fmap concat . forM floatTypesList $ \arithtype -> do
     let ttyp = conT (atType arithtype)
     fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -794,6 +808,34 @@ instance NumElt CInt where
   numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
                                                  (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
 
+class NumElt a => IntElt a where
+  intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+  intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+
+instance IntElt Int32 where
+  intEltQuot = quotVectorInt32
+  intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+  intEltQuot = quotVectorInt64
+  intEltRem = remVectorInt64
+
+instance IntElt Int where
+  intEltQuot = intWidBranch2 @Int quot
+                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @Int rem
+                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+  intEltQuot = intWidBranch2 @CInt quot
+                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @CInt rem
+                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
 class NumElt a => FloatElt a where
   floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
   floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
-- 
cgit v1.2.3-70-g09d2