diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-12 23:20:13 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:27:51 +0100 | 
| commit | ed6acbe5f409aba2fb222693da567ce04b7c4e01 (patch) | |
| tree | becbef3f3afeed63c248f057dae6fef0cb6c6147 /src/Data/Array/Mixed | |
| parent | bcda5b7eb20874f948fbdc23b6daa3ebb792ffe0 (diff) | |
Implement quot/rem
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 42 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 11 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 27 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 3 | 
4 files changed, 75 insertions, 8 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 313c885..11cbba6 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -502,6 +502,20 @@ $(fmap concat . forM typesList $ \arithtype -> do                 ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]                     return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM intTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_ibinary_" ++ atCName arithtype +          c_ss_str = varE (aiboNumOp arithop) +          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) +          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) +          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] +                   return $ FunD name [Clause [] (NormalB body) []]]) +  $(fmap concat . forM floatTypesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -794,6 +808,34 @@ instance NumElt CInt where    numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32                                                   (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 +class NumElt a => IntElt a where +  intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + +instance IntElt Int32 where +  intEltQuot = quotVectorInt32 +  intEltRem = remVectorInt32 + +instance IntElt Int64 where +  intEltQuot = quotVectorInt64 +  intEltRem = remVectorInt64 + +instance IntElt Int where +  intEltQuot = intWidBranch2 @Int quot +                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) +                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) +  intEltRem = intWidBranch2 @Int rem +                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) +                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + +instance IntElt CInt where +  intEltQuot = intWidBranch2 @CInt quot +                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) +                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) +  intEltRem = intWidBranch2 @CInt rem +                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) +                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) +  class NumElt a => FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a    floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index fa89766..15fbc79 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -25,6 +25,12 @@ $(do          ,("dotprodinner_" ++ tyn,            [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ] +  let importsInt ttyp tyn = +        [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> IO () |]) +        ] +    let importsFloat ttyp tyn =          [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) @@ -38,5 +44,6 @@ $(do            | arithtype <- types            , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)]    decs1 <- generate typesList importsScal -  decs2 <- generate floatTypesList importsFloat -  return (decs1 ++ decs2)) +  decs2 <- generate intTypesList importsInt +  decs3 <- generate floatTypesList importsFloat +  return (decs1 ++ decs2 ++ decs3)) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs index a284bc1..370b708 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -14,6 +14,12 @@ data ArithType = ArithType    , atCName :: String  -- "i32"    } +intTypesList :: [ArithType] +intTypesList = +  [ArithType ''Int32 "i32" +  ,ArithType ''Int64 "i64" +  ] +  floatTypesList :: [ArithType]  floatTypesList =    [ArithType ''Float "float" @@ -21,11 +27,7 @@ floatTypesList =    ]  typesList :: [ArithType] -typesList = -  [ArithType ''Int32 "i32" -  ,ArithType ''Int64 "i64" -  ] -  ++ floatTypesList +typesList = intTypesList ++ floatTypesList  -- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)  $(genArithDataType Binop "ArithBOp") @@ -42,6 +44,21 @@ $(do clauses <- readArithLists Binop                ,return $ FunD (mkName "aboNumOp") clauses]) +-- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded) +$(genArithDataType IBinop "ArithIBOp") + +$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3)) +$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum") + +$(do clauses <- readArithLists IBinop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |] +              ,return $ FunD (mkName "aiboNumOp") clauses]) + +  -- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)  $(genArithDataType FBinop "ArithFBOp") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs index 8b7d05f..a156e29 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -10,7 +10,7 @@ import Language.Haskell.TH.Syntax  import Text.Read -data OpKind = Binop | FBinop | Unop | FUnop | Redop +data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop    deriving (Show, Eq)  readArithLists :: OpKind @@ -48,6 +48,7 @@ readArithLists targetkind fop fcombine = do      parseField s = break (`elem` ",)") (dropWhile (== ' ') s)      parseKind "BINOP" = Just Binop +    parseKind "IBINOP" = Just IBinop      parseKind "FBINOP" = Just FBinop      parseKind "UNOP" = Just Unop      parseKind "FUNOP" = Just FUnop | 
