diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 14:57:34 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 14:57:34 +0200 |
commit | e80b2593edc3d216905279ebcfa797593a1efbfc (patch) | |
tree | 5e5057e03f35369983f6600efc59c438c0cf2366 /src/Data/Array | |
parent | 2ac16efe59051e0cdeb37422ab579c8d354d562a (diff) |
Fast Fractional ops via C code
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 16 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 56 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 18 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 31 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 4 |
5 files changed, 104 insertions, 21 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 94f08bf..ef2ad6b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1048,12 +1048,12 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where signum = mliftNumElt1 numEltSignum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftPrim recip - (/) = mliftPrim2 (/) + recip = mliftNumElt1 floatEltRecip + (/) = mliftNumElt2 floatEltDiv -instance (NumElt a, PrimElt a, Floating a) => Floating (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" exp = mliftPrim exp log = mliftPrim log @@ -1367,12 +1367,12 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where signum = arithPromoteRanked signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) -instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" exp = arithPromoteRanked exp log = arithPromoteRanked log @@ -1698,12 +1698,12 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where signum = arithPromoteShaped signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" -instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) -instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" exp = arithPromoteShaped exp log = arithPromoteShaped log diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 7484455..07d5d8a 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () flipOp f n out v s = f n out s v -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_fbinary_" ++ atCName arithtype + c_ss = varE (afboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 | otherwise = error "Unsupported Int width" +class NumElt a where + numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a + numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a + numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + instance NumElt Int32 where numEltAdd = addVectorInt32 numEltSub = subVectorInt32 @@ -334,3 +358,15 @@ instance NumElt CInt where numEltProduct1Inner = intWidBranchRed @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where + floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where + floatEltDiv = divVectorFloat + floatEltRecip = recipVectorFloat + +instance FloatElt Double where + floatEltDiv = divVectorDouble + floatEltRecip = recipVectorDouble diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index 49effa1..ac83188 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -24,12 +24,30 @@ $(fmap concat . forM typesList $ \arithtype -> do [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) ]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "fbinary_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) let base = "unary_" ++ atCName arithtype pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "funary_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) let base = "reduce_" ++ atCName arithtype diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 91e50ad..ce2836d 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -14,13 +14,18 @@ data ArithType = ArithType , atCName :: String -- "i32" } +floatTypesList :: [ArithType] +floatTypesList = + [ArithType ''Float "float" + ,ArithType ''Double "double" + ] + typesList :: [ArithType] typesList = [ArithType ''Int32 "i32" ,ArithType ''Int64 "i64" - ,ArithType ''Float "float" - ,ArithType ''Double "double" ] + ++ floatTypesList -- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) $(genArithDataType Binop "ArithBOp") @@ -37,6 +42,21 @@ $(do clauses <- readArithLists Binop ,return $ FunD (mkName "aboNumOp") clauses]) +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] + ,return $ FunD (mkName "afboNumOp") clauses]) + + -- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) $(genArithDataType Unop "ArithUOp") @@ -44,6 +64,13 @@ $(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) $(genArithEnumFun Unop ''ArithUOp "auoEnum") +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + -- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) $(genArithDataType Redop "ArithRedOp") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs index b748b97..b40a066 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs @@ -9,7 +9,7 @@ import Language.Haskell.TH import Text.Read -data OpKind = Binop | Unop | Redop +data OpKind = Binop | FBinop | Unop | FUnop | Redop deriving (Show, Eq) readArithLists :: OpKind @@ -46,7 +46,9 @@ readArithLists targetkind fop fcombine = do parseField s = break (`elem` ",)") (dropWhile (== ' ') s) parseKind "BINOP" = Just Binop + parseKind "FBINOP" = Just FBinop parseKind "UNOP" = Just Unop + parseKind "FUNOP" = Just FUnop parseKind "REDOP" = Just Redop parseKind _ = Nothing |