aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-12 23:20:13 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-13 09:27:51 +0100
commited6acbe5f409aba2fb222693da567ce04b7c4e01 (patch)
treebecbef3f3afeed63c248f057dae6fef0cb6c6147
parentbcda5b7eb20874f948fbdc23b6daa3ebb792ffe0 (diff)
Implement quot/rem
-rw-r--r--cbits/arith.c42
-rw-r--r--cbits/arith_lists.h3
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs42
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs11
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists.hs27
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs3
-rw-r--r--src/Data/Array/Nested.hs20
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs4
10 files changed, 151 insertions, 9 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 9aed3b4..4646ca4 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -18,6 +18,7 @@
// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
#define LIST_BINOP(name, id, hsop)
+#define LIST_IBINOP(name, id, hsop)
#define LIST_FBINOP(name, id, hsop)
#define LIST_UNOP(name, id, _)
#define LIST_FUNOP(name, id, _)
@@ -410,6 +411,37 @@ enum binop_tag_t {
} \
}
+enum ibinop_tag_t {
+#undef LIST_IBINOP
+#define LIST_IBINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_IBINOP
+#define LIST_IBINOP(name, id, hsop)
+};
+
+#define ENTRY_IBINARY_STRIDED_OPS(typ) \
+ void oxarop_ibinary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("ibinary_sv_strided", tag); \
+ } \
+ } \
+ void oxarop_ibinary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ default: wrong_op("ibinary_vs_strided", tag); \
+ } \
+ } \
+ void oxarop_ibinary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("ibinary_vv_strided", tag); \
+ } \
+ }
+
enum fbinop_tag_t {
#undef LIST_FBINOP
#define LIST_FBINOP(name, id, hsop) name = id,
@@ -528,8 +560,9 @@ enum redop_tag_t {
* Generate all the functions *
*****************************************************************************/
+#define INT_TYPES_XLIST X(i32) X(i64)
#define FLOAT_TYPES_XLIST X(double) X(float)
-#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST
+#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST
#define X(typ) \
COMM_OP_STRIDED(add, +, typ) \
@@ -554,6 +587,13 @@ NUM_TYPES_XLIST
#undef X
#define X(typ) \
+ NONCOMM_OP_STRIDED(quot, /, typ) \
+ NONCOMM_OP_STRIDED(rem, %, typ) \
+ ENTRY_IBINARY_STRIDED_OPS(typ)
+INT_TYPES_XLIST
+#undef X
+
+#define X(typ) \
NONCOMM_OP_STRIDED(fdiv, /, typ) \
PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 58de65a..b651b62 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -2,6 +2,9 @@ LIST_BINOP(BO_ADD, 1, +)
LIST_BINOP(BO_SUB, 2, -)
LIST_BINOP(BO_MUL, 3, *)
+LIST_IBINOP(IB_QUOT, 1, quot)
+LIST_IBINOP(IB_REM, 2, rem)
+
LIST_FBINOP(FB_DIV, 1, /)
LIST_FBINOP(FB_POW, 2, **)
LIST_FBINOP(FB_LOGBASE, 3, logBase)
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 313c885..11cbba6 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM intTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_ibinary_" ++ atCName arithtype
+ c_ss_str = varE (aiboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
$(fmap concat . forM floatTypesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
@@ -794,6 +808,34 @@ instance NumElt CInt where
numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
(scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+class NumElt a => IntElt a where
+ intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+
+instance IntElt Int32 where
+ intEltQuot = quotVectorInt32
+ intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+ intEltQuot = quotVectorInt64
+ intEltRem = remVectorInt64
+
+instance IntElt Int where
+ intEltQuot = intWidBranch2 @Int quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @Int rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+ intEltQuot = intWidBranch2 @CInt quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @CInt rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index fa89766..15fbc79 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -25,6 +25,12 @@ $(do
,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
]
+ let importsInt ttyp tyn =
+ [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ]
+
let importsFloat ttyp tyn =
[("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
@@ -38,5 +44,6 @@ $(do
| arithtype <- types
, (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)]
decs1 <- generate typesList importsScal
- decs2 <- generate floatTypesList importsFloat
- return (decs1 ++ decs2))
+ decs2 <- generate intTypesList importsInt
+ decs3 <- generate floatTypesList importsFloat
+ return (decs1 ++ decs2 ++ decs3))
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
index a284bc1..370b708 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
@@ -14,6 +14,12 @@ data ArithType = ArithType
, atCName :: String -- "i32"
}
+intTypesList :: [ArithType]
+intTypesList =
+ [ArithType ''Int32 "i32"
+ ,ArithType ''Int64 "i64"
+ ]
+
floatTypesList :: [ArithType]
floatTypesList =
[ArithType ''Float "float"
@@ -21,11 +27,7 @@ floatTypesList =
]
typesList :: [ArithType]
-typesList =
- [ArithType ''Int32 "i32"
- ,ArithType ''Int64 "i64"
- ]
- ++ floatTypesList
+typesList = intTypesList ++ floatTypesList
-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
$(genArithDataType Binop "ArithBOp")
@@ -42,6 +44,21 @@ $(do clauses <- readArithLists Binop
,return $ FunD (mkName "aboNumOp") clauses])
+-- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded)
+$(genArithDataType IBinop "ArithIBOp")
+
+$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3))
+$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum")
+
+$(do clauses <- readArithLists IBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |]
+ ,return $ FunD (mkName "aiboNumOp") clauses])
+
+
-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
$(genArithDataType FBinop "ArithFBOp")
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
index 8b7d05f..a156e29 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
@@ -10,7 +10,7 @@ import Language.Haskell.TH.Syntax
import Text.Read
-data OpKind = Binop | FBinop | Unop | FUnop | Redop
+data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop
deriving (Show, Eq)
readArithLists :: OpKind
@@ -48,6 +48,7 @@ readArithLists targetkind fop fcombine = do
parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
parseKind "BINOP" = Just Binop
+ parseKind "IBINOP" = Just IBinop
parseKind "FBINOP" = Just FBinop
parseKind "UNOP" = Just Unop
parseKind "FUNOP" = Just FUnop
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index bef83d1..58c0c71 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -21,6 +21,10 @@ module Data.Array.Nested (
rtoXArrayPrim, rfromXArrayPrim,
rcastToShaped, rtoMixed, rcastToMixed,
rfromOrthotope, rtoOrthotope,
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ rquotArray, rremArray,
-- * Shaped arrays
Shaped(Shaped),
@@ -43,6 +47,10 @@ module Data.Array.Nested (
stoXArrayPrim, sfromXArrayPrim,
stoRanked, stoMixed, scastToMixed,
sfromOrthotope, stoOrthotope,
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ squotArray, sremArray,
-- * Mixed arrays
Mixed,
@@ -68,6 +76,10 @@ module Data.Array.Nested (
mcastSafe, SafeMCast, SafeMCastSpec(..),
mtoRanked, mcastToShaped,
castCastable, Castable(..),
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ mquotArray, mremArray,
-- * Array elements
Elt,
@@ -102,3 +114,11 @@ import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped
import Foreign.Storable
import GHC.TypeLits
+
+-- $integralRealFloat
+--
+-- These functions separate top-level functions, and not exposed in instances
+-- for 'RealFloat' and 'Integral', because those classes include a variety of
+-- other functions that make no sense for arrays.
+-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
+-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 08f97f0..7e1f100 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -265,6 +265,10 @@ instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
log1pexp = mliftNumElt1 floatEltLog1pexp
log1mexp = mliftNumElt1 floatEltLog1mexp
+mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+mquotArray = mliftNumElt2 intEltQuot
+mremArray = mliftNumElt2 intEltRem
+
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 3bdd44e..4fb29e0 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -236,6 +236,10 @@ instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
log1pexp = arithPromoteRanked GHC.Float.log1pexp
log1mexp = arithPromoteRanked GHC.Float.log1mexp
+rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+rquotArray = arithPromoteRanked2 mquotArray
+rremArray = arithPromoteRanked2 mremArray
+
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked (memptyArray ZSX)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index f75519f..ed616cf 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -234,6 +234,10 @@ instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
log1pexp = arithPromoteShaped GHC.Float.log1pexp
log1mexp = arithPromoteShaped GHC.Float.log1mexp
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = arithPromoteShaped2 mquotArray
+sremArray = arithPromoteShaped2 mremArray
+
semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
semptyArray sh = Shaped (memptyArray (shCvtSX sh))