diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-12 23:20:13 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:27:51 +0100 |
commit | ed6acbe5f409aba2fb222693da567ce04b7c4e01 (patch) | |
tree | becbef3f3afeed63c248f057dae6fef0cb6c6147 | |
parent | bcda5b7eb20874f948fbdc23b6daa3ebb792ffe0 (diff) |
Implement quot/rem
-rw-r--r-- | cbits/arith.c | 42 | ||||
-rw-r--r-- | cbits/arith_lists.h | 3 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 42 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 11 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 27 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 20 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 4 |
10 files changed, 151 insertions, 9 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 9aed3b4..4646ca4 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -18,6 +18,7 @@ // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. #define LIST_BINOP(name, id, hsop) +#define LIST_IBINOP(name, id, hsop) #define LIST_FBINOP(name, id, hsop) #define LIST_UNOP(name, id, _) #define LIST_FUNOP(name, id, _) @@ -410,6 +411,37 @@ enum binop_tag_t { } \ } +enum ibinop_tag_t { +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) +}; + +#define ENTRY_IBINARY_STRIDED_OPS(typ) \ + void oxarop_ibinary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + default: wrong_op("ibinary_sv_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + default: wrong_op("ibinary_vs_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + default: wrong_op("ibinary_vv_strided", tag); \ + } \ + } + enum fbinop_tag_t { #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) name = id, @@ -528,8 +560,9 @@ enum redop_tag_t { * Generate all the functions * *****************************************************************************/ +#define INT_TYPES_XLIST X(i32) X(i64) #define FLOAT_TYPES_XLIST X(double) X(float) -#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST +#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST #define X(typ) \ COMM_OP_STRIDED(add, +, typ) \ @@ -554,6 +587,13 @@ NUM_TYPES_XLIST #undef X #define X(typ) \ + NONCOMM_OP_STRIDED(quot, /, typ) \ + NONCOMM_OP_STRIDED(rem, %, typ) \ + ENTRY_IBINARY_STRIDED_OPS(typ) +INT_TYPES_XLIST +#undef X + +#define X(typ) \ NONCOMM_OP_STRIDED(fdiv, /, typ) \ PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \ PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \ diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index 58de65a..b651b62 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -2,6 +2,9 @@ LIST_BINOP(BO_ADD, 1, +) LIST_BINOP(BO_SUB, 2, -) LIST_BINOP(BO_MUL, 3, *) +LIST_IBINOP(IB_QUOT, 1, quot) +LIST_IBINOP(IB_REM, 2, rem) + LIST_FBINOP(FB_DIV, 1, /) LIST_FBINOP(FB_POW, 2, **) LIST_FBINOP(FB_LOGBASE, 3, logBase) diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 313c885..11cbba6 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM intTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_ibinary_" ++ atCName arithtype + c_ss_str = varE (aiboNumOp arithop) + c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM floatTypesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -794,6 +808,34 @@ instance NumElt CInt where numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 +class NumElt a => IntElt a where + intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + +instance IntElt Int32 where + intEltQuot = quotVectorInt32 + intEltRem = remVectorInt32 + +instance IntElt Int64 where + intEltQuot = quotVectorInt64 + intEltRem = remVectorInt64 + +instance IntElt Int where + intEltQuot = intWidBranch2 @Int quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @Int rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + +instance IntElt CInt where + intEltQuot = intWidBranch2 @CInt quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @CInt rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + class NumElt a => FloatElt a where floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index fa89766..15fbc79 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -25,6 +25,12 @@ $(do ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ] + let importsInt ttyp tyn = + [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ] + let importsFloat ttyp tyn = [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) @@ -38,5 +44,6 @@ $(do | arithtype <- types , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] decs1 <- generate typesList importsScal - decs2 <- generate floatTypesList importsFloat - return (decs1 ++ decs2)) + decs2 <- generate intTypesList importsInt + decs3 <- generate floatTypesList importsFloat + return (decs1 ++ decs2 ++ decs3)) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs index a284bc1..370b708 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -14,6 +14,12 @@ data ArithType = ArithType , atCName :: String -- "i32" } +intTypesList :: [ArithType] +intTypesList = + [ArithType ''Int32 "i32" + ,ArithType ''Int64 "i64" + ] + floatTypesList :: [ArithType] floatTypesList = [ArithType ''Float "float" @@ -21,11 +27,7 @@ floatTypesList = ] typesList :: [ArithType] -typesList = - [ArithType ''Int32 "i32" - ,ArithType ''Int64 "i64" - ] - ++ floatTypesList +typesList = intTypesList ++ floatTypesList -- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) $(genArithDataType Binop "ArithBOp") @@ -42,6 +44,21 @@ $(do clauses <- readArithLists Binop ,return $ FunD (mkName "aboNumOp") clauses]) +-- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded) +$(genArithDataType IBinop "ArithIBOp") + +$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3)) +$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum") + +$(do clauses <- readArithLists IBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |] + ,return $ FunD (mkName "aiboNumOp") clauses]) + + -- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) $(genArithDataType FBinop "ArithFBOp") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs index 8b7d05f..a156e29 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -10,7 +10,7 @@ import Language.Haskell.TH.Syntax import Text.Read -data OpKind = Binop | FBinop | Unop | FUnop | Redop +data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop deriving (Show, Eq) readArithLists :: OpKind @@ -48,6 +48,7 @@ readArithLists targetkind fop fcombine = do parseField s = break (`elem` ",)") (dropWhile (== ' ') s) parseKind "BINOP" = Just Binop + parseKind "IBINOP" = Just IBinop parseKind "FBINOP" = Just FBinop parseKind "UNOP" = Just Unop parseKind "FUNOP" = Just FUnop diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index bef83d1..58c0c71 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -21,6 +21,10 @@ module Data.Array.Nested ( rtoXArrayPrim, rfromXArrayPrim, rcastToShaped, rtoMixed, rcastToMixed, rfromOrthotope, rtoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + rquotArray, rremArray, -- * Shaped arrays Shaped(Shaped), @@ -43,6 +47,10 @@ module Data.Array.Nested ( stoXArrayPrim, sfromXArrayPrim, stoRanked, stoMixed, scastToMixed, sfromOrthotope, stoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + squotArray, sremArray, -- * Mixed arrays Mixed, @@ -68,6 +76,10 @@ module Data.Array.Nested ( mcastSafe, SafeMCast, SafeMCastSpec(..), mtoRanked, mcastToShaped, castCastable, Castable(..), + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + mquotArray, mremArray, -- * Array elements Elt, @@ -102,3 +114,11 @@ import Data.Array.Nested.Internal.Shape import Data.Array.Nested.Internal.Shaped import Foreign.Storable import GHC.TypeLits + +-- $integralRealFloat +-- +-- These functions separate top-level functions, and not exposed in instances +-- for 'RealFloat' and 'Integral', because those classes include a variety of +-- other functions that make no sense for arrays. +-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but +-- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 08f97f0..7e1f100 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -265,6 +265,10 @@ instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where log1pexp = mliftNumElt1 floatEltLog1pexp log1mexp = mliftNumElt1 floatEltLog1mexp +mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +mquotArray = mliftNumElt2 intEltQuot +mremArray = mliftNumElt2 intEltRem + -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 3bdd44e..4fb29e0 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -236,6 +236,10 @@ instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where log1pexp = arithPromoteRanked GHC.Float.log1pexp log1mexp = arithPromoteRanked GHC.Float.log1mexp +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = arithPromoteRanked2 mquotArray +rremArray = arithPromoteRanked2 mremArray + remptyArray :: KnownElt a => Ranked 1 a remptyArray = mtoRanked (memptyArray ZSX) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index f75519f..ed616cf 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -234,6 +234,10 @@ instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where log1pexp = arithPromoteShaped GHC.Float.log1pexp log1mexp = arithPromoteShaped GHC.Float.log1mexp +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = arithPromoteShaped2 mquotArray +sremArray = arithPromoteShaped2 mremArray + semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shCvtSX sh)) |