diff options
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 435 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 55 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 78 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 82 |
4 files changed, 650 insertions, 0 deletions
-- Reconstructed from a patch; original patch header (line counts refer to
-- the original patch, not to this listing):
--
-- diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
-- new file mode 100644
-- index 0000000..cf6820b
-- --- /dev/null
-- +++ b/src/Data/Array/Mixed/Internal/Arith.hs
-- @@ -0,0 +1,435 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Internal.Arith where

import Control.Monad (forM, guard)
import qualified Data.Array.Internal as OI
import qualified Data.Array.Internal.RankedG as RG
import qualified Data.Array.Internal.RankedS as RS
import Data.Bits
import Data.Int
import Data.List (sort)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable (Storable)
import GHC.TypeLits
import Language.Haskell.TH
import System.IO.Unsafe

import Data.Array.Mixed.Internal.Arith.Foreign
import Data.Array.Mixed.Internal.Arith.Lists


-- | Lift an elementwise operation on flat vectors to a ranked array.
--
-- If the strides of the array address a dense prefix of the backing vector
-- (see 'stridesDense'), the operation is applied to exactly that prefix and
-- the original strides are reused on the result, with the offset reset to 0.
-- Otherwise the array is first normalised with 'RS.toVector'.
--
-- NOTE(review): this relies on @f@ returning a vector of the same length as
-- its argument; all callers in this module satisfy that.
liftVEltwise1 :: Storable a
              => SNat n
              -> (VS.Vector a -> VS.Vector a)
              -> RS.Array n a -> RS.Array n a
liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
  | Just prefixSz <- stridesDense sh strides =
      let vec' = f (VS.slice offset prefixSz vec)
      in RS.A (RG.A sh (OI.T strides 0 vec'))
  | otherwise = RS.fromVector sh (f (RS.toVector arr))

-- | Lift a binary elementwise operation to two ranked arrays of equal shape.
--
-- The operation receives each operand either as a single scalar ('Left') or
-- as a flat vector ('Right'). Calls 'error' if the shapes differ. If the
-- arrays are empty, the first argument is returned unchanged.
--
-- Dispatch, based on 'stridesDense' of both operands:
--
-- * both operands address a single element: apply @f@ to the two scalars;
-- * one operand addresses a single element and the other a dense prefix:
--   apply @f@ to the scalar and the prefix, and give the result the strides
--   of the dense operand;
-- * otherwise: normalise both operands with 'RS.toVector'.
liftVEltwise2 :: Storable a
              => SNat n
              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
              -> RS.Array n a -> RS.Array n a -> RS.Array n a
liftVEltwise2 SNat f
    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs
  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
      (Just 1, Just n) ->  -- scalar * dense
        -- The dense prefix is in the memory order dictated by strides2, which
        -- need not be row-major (transposition) and may be shorter than
        -- 'product sh1' (replicated dimensions). Hence the result must keep
        -- strides2 instead of being reinterpreted via 'RS.fromVector'.
        let vec' = f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))
        in RS.A (RG.A sh1 (OI.T strides2 0 vec'))
      (Just n, Just 1) ->  -- dense * scalar
        -- Same reasoning as above, with the roles of the operands swapped.
        let vec' = f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))
        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
      (_, _) ->  -- fallback case
        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))

-- | Given the shape vector and the stride vector, return whether this vector
-- of strides uses a dense prefix of its backing array. If so, the number of
-- elements in this prefix is returned.
-- This excludes any offset.
--
-- An array with a non-positive dimension is reported as covering 0 elements.
-- Dimensions with stride 0 (replication) are ignored. 'Nothing' is returned
-- if the smallest remaining stride is not 1 or if the strides leave gaps.
stridesDense :: [Int] -> [Int] -> Maybe Int
stridesDense sh _ | any (<= 0) sh = Just 0
stridesDense sh str =
  -- sort dimensions on their stride, ascending, dropping any zero strides
  case dropWhile ((== 0) . fst) (sort (zip str sh)) of
    [] -> Just 1
    (1, n) : rest -> checkCover n rest
    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
  where
    -- Given size of currently densely covered region at beginning of the
    -- array and the remaining (stride, size) pairs in ascending stride order,
    -- return whether this all together covers a dense prefix of the array.
    -- If it does, return the number of elements in this prefix.
    checkCover :: Int -> [(Int, Int)] -> Maybe Int
    checkCover block [] = Just block
    checkCover block ((s, n) : rest) =
      -- A stride larger than the covered block would leave a gap. Otherwise
      -- the n copies of the block start at 0, s, ..., (n-1)*s, so the last
      -- covered index is (n-1)*s + block - 1. (Using n*s here would
      -- underestimate the extent when s < block, i.e. overlapping strides.)
      guard (s <= block) >> checkCover ((n - 1) * s + block) rest

-- | Run a unary foreign kernel over a vector, producing a fresh vector of
-- the same length. The kernel receives the element count, the output
-- pointer and the input pointer, in that order; @ptrconv@ converts the
-- Haskell-side pointer type to the one the kernel expects.
--
-- NOTE(review): the use of 'unsafePerformIO' presumes that the kernel only
-- reads its input and only writes the freshly allocated output buffer --
-- confirm against the C implementation.
{-# NOINLINE vectorOp1 #-}
vectorOp1 :: forall a b. Storable a
          => (Ptr a -> Ptr b)
          -> (Int64 -> Ptr b -> Ptr b -> IO ())
          -> VS.Vector a -> VS.Vector a
vectorOp1 ptrconv f v = unsafePerformIO $ do
  outv <- VSM.unsafeNew (VS.length v)
  VSM.unsafeWith outv $ \poutv ->
    VS.unsafeWith v $ \pv ->
      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
  VS.unsafeFreeze outv

-- | Run a binary operation where each operand is either a scalar or a
-- vector. Two scalars are combined in Haskell with @fss@ and yield a
-- singleton vector; the other combinations call the scalar-vector,
-- vector-scalar or vector-vector foreign kernel respectively.
--
-- If two vectors are given, they must have the same length; otherwise
-- 'error' is called.
--
-- NOTE(review): as for 'vectorOp1', 'unsafePerformIO' presumes the kernels
-- only write to the freshly allocated output buffer.
{-# NOINLINE vectorOp2 #-}
vectorOp2 :: forall a b. Storable a
          => (a -> b)
          -> (Ptr a -> Ptr b)
          -> (a -> a -> a)
          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv
          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs
          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv
          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
  (Left x) (Left y) -> VS.singleton (fss x y)

  (Left x) (Right vy) ->
    unsafePerformIO $ do
      outv <- VSM.unsafeNew (VS.length vy)
      VSM.unsafeWith outv $ \poutv ->
        VS.unsafeWith vy $ \pvy ->
          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
      VS.unsafeFreeze outv

  (Right vx) (Left y) ->
    unsafePerformIO $ do
      outv <- VSM.unsafeNew (VS.length vx)
      VSM.unsafeWith outv $ \poutv ->
        VS.unsafeWith vx $ \pvx ->
          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
      VS.unsafeFreeze outv

  (Right vx) (Right vy)
    | VS.length vx == VS.length vy ->
        unsafePerformIO $ do
          outv <- VSM.unsafeNew (VS.length vx)
          VSM.unsafeWith outv $ \poutv ->
            VS.unsafeWith vx $ \pvx ->
              VS.unsafeWith vy $ \pvy ->
                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
          VS.unsafeFreeze outv
    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)

-- TODO: test all the weird cases of this function
-- | Reduce along the inner dimension.
--
-- Special cases are handled before the reduction kernel is invoked:
--
-- * empty inner dimension: an array of zeros of the outer shape;
-- * some empty outer dimension: an empty array;
-- * inner dimension of size 1: the dimension is simply dropped;
-- * replicated inner dimension (stride 0): the remaining array is scaled by
--   the size of the inner dimension using @fscale@.
--
-- NOTE(review): the first and last special case are only correct for a sum.
-- This function is also instantiated with the product reduction, for which
-- an empty inner dimension should give 1 and a replicated inner dimension of
-- size k should give x^k rather than k*x. Fixing this requires telling this
-- function which reduction it performs, i.e. an interface change.
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
                 => SNat n
                 -> (a -> b)
                 -> (Ptr a -> Ptr b)
                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant
                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel
                 -> RS.Array (n + 1) a -> RS.Array n a
vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
  | null sh = error "unreachable"
  | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0])
  | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
  -- now the input array is nonempty
  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
  | last strides == 0 =
      liftVEltwise1 sn
        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
  -- now there is useful work along the inner dimension
  | otherwise =
      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
          ndimsF = length shF
          -- Shape of the kernel's output: the non-replicated outer dimensions.
          -- (shF is nonempty here because the inner stride is nonzero.)
          outShF = init shF
          -- Row-major strides for the kernel's output buffer.
          compactStrides = drop 1 (scanr (*) 1 outShF)
          -- Strides for the full outer shape: replicated outer dimensions
          -- keep stride 0, all others take the next compact stride. Wrapping
          -- the buffer with 'RS.fromVector (init sh)' instead would fail
          -- whenever a replicated outer dimension is present, because the
          -- buffer only holds 'product outShF' elements.
          outStrides = reinsertZeros (init strides) compactStrides
      in unsafePerformIO $ do
           outv <- VSM.unsafeNew (product outShF)
           VSM.unsafeWith outv $ \poutv ->
             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
           outv' <- VS.unsafeFreeze outv
           return (RS.A (RG.A (init sh) (OI.T outStrides 0 outv')))
  where
    -- Walk over the original outer strides; emit 0 for replicated
    -- dimensions and consume one compact stride for every other dimension.
    reinsertZeros :: [Int] -> [Int] -> [Int]
    reinsertZeros (0 : ss) cs = 0 : reinsertZeros ss cs
    reinsertZeros (_ : ss) (c : cs) = c : reinsertZeros ss cs
    reinsertZeros _ _ = []

-- | Turn a scalar-vector kernel into a vector-scalar kernel by swapping the
-- scalar and vector arguments. Only meaningful for commutative operations.
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
       ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v

-- For every type in 'typesList' and every 'ArithBOp', generate
--   <op>Vector<Type> :: SNat n -> RS.Array n T -> RS.Array n T -> RS.Array n T
-- (e.g. addVectorInt32) in terms of 'liftVEltwise2' and the c_binary_* imports.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
    fmap concat . forM [minBound..maxBound] $ \arithop -> do
      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
          cnamebase = "c_binary_" ++ atCName arithtype
          c_ss = varE (aboNumOp arithop)
          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
      sequence [SigD name <$>
                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
                   return $ FunD name [Clause [] (NormalB body) []]])

-- The same for every type in 'floatTypesList' and every 'ArithFBOp', using
-- the c_fbinary_* imports.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
    fmap concat . forM [minBound..maxBound] $ \arithop -> do
      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
          cnamebase = "c_fbinary_" ++ atCName arithtype
          c_ss = varE (afboNumOp arithop)
          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
      sequence [SigD name <$>
                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
                   return $ FunD name [Clause [] (NormalB body) []]])

-- Unary operations ('ArithUOp') for every type in 'typesList', in terms of
-- 'liftVEltwise1' and the c_unary_* imports.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
    fmap concat . forM [minBound..maxBound] $ \arithop -> do
      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
      sequence [SigD name <$>
                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
                   return $ FunD name [Clause [] (NormalB body) []]])

-- Floating-point unary operations ('ArithFUOp') for every type in
-- 'floatTypesList', using the c_funary_* imports.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
    fmap concat . forM [minBound..maxBound] $ \arithop -> do
      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
      sequence [SigD name <$>
                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
                   return $ FunD name [Clause [] (NormalB body) []]])

-- Inner-dimension reductions ('ArithRedOp') for every type in 'typesList',
-- in terms of 'vectorRedInnerOp'. The "scale by constant" kernel is always
-- the scalar-vector multiplication kernel.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
    fmap concat . forM [minBound..maxBound] $ \arithop -> do
      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
      sequence [SigD name <$>
                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
                   return $ FunD name [Clause [] (NormalB body) []]])

-- | Select the 32-bit or 64-bit unary kernel depending on the bit width of
-- the integral type @i@; calls 'error' for any other width.
--
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
              -> (SNat n -> RS.Array n i -> RS.Array n i)
intWidBranch1 f32 f64 sn
  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
  | otherwise = error "Unsupported Int width"

-- | Like 'intWidBranch1', for binary operations: selects the 32-bit or
-- 64-bit set of scalar-vector, vector-scalar and vector-vector kernels.
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
              => (i -> i -> i)  -- ss
                 -- int32
              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv
              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs
              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv
                 -- int64
              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv
              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs
              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv
              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
  | otherwise = error "Unsupported Int width"

-- | Like 'intWidBranch1', for inner-dimension reductions.
intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
                => -- int32
                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant
                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
                   -- int64
                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant
                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel
                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
intWidBranchRed fsc32 fred32 fsc64 fred64 sn
  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
  | otherwise = error "Unsupported Int width"

-- | Element types that support the numeric array operations of this module.
class NumElt a where
  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a

instance NumElt Int32 where
  numEltAdd = addVectorInt32
  numEltSub = subVectorInt32
  numEltMul = mulVectorInt32
  numEltNeg = negVectorInt32
  numEltAbs = absVectorInt32
  numEltSignum = signumVectorInt32
  numEltSum1Inner = sum1VectorInt32
  numEltProduct1Inner = product1VectorInt32

instance NumElt Int64 where
  numEltAdd = addVectorInt64
  numEltSub = subVectorInt64
  numEltMul = mulVectorInt64
  numEltNeg = negVectorInt64
  numEltAbs = absVectorInt64
  numEltSignum = signumVectorInt64
  numEltSum1Inner = sum1VectorInt64
  numEltProduct1Inner = product1VectorInt64

instance NumElt Float where
  numEltAdd = addVectorFloat
  numEltSub = subVectorFloat
  numEltMul = mulVectorFloat
  numEltNeg = negVectorFloat
  numEltAbs = absVectorFloat
  numEltSignum = signumVectorFloat
  numEltSum1Inner = sum1VectorFloat
  numEltProduct1Inner = product1VectorFloat

instance NumElt Double where
  numEltAdd = addVectorDouble
  numEltSub = subVectorDouble
  numEltMul = mulVectorDouble
  numEltNeg = negVectorDouble
  numEltAbs = absVectorDouble
  numEltSignum = signumVectorDouble
  numEltSum1Inner = sum1VectorDouble
  numEltProduct1Inner = product1VectorDouble

-- 'Int' has a platform-dependent width, so every method picks the 32-bit or
-- 64-bit kernel through the intWidBranch* helpers.
-- NOTE(review): 'flipOp' is also applied to the subtraction kernel below,
-- although subtraction is not commutative -- confirm that vector-minus-scalar
-- yields the intended result.
instance NumElt Int where
  numEltAdd = intWidBranch2 @Int (+)
                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
  numEltSub = intWidBranch2 @Int (-)
                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
  numEltMul = intWidBranch2 @Int (*)
                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
  numEltSum1Inner = intWidBranchRed @Int
                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
  numEltProduct1Inner = intWidBranchRed @Int
                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))

-- (continuation of src/Data/Array/Mixed/Internal/Arith.hs)

-- 'CInt' is handled like 'Int': every method picks the 32-bit or 64-bit
-- kernel through the intWidBranch* helpers, based on the bit width of the
-- type.
instance NumElt CInt where
  numEltAdd =
    intWidBranch2 @CInt (+)
      (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
      (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
  numEltSub =
    intWidBranch2 @CInt (-)
      (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
      (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
  numEltMul =
    intWidBranch2 @CInt (*)
      (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
      (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
  numEltSum1Inner =
    intWidBranchRed @CInt
      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
  numEltProduct1Inner =
    intWidBranchRed @CInt
      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))

-- | Element types that additionally support the floating-point array
-- operations. The first three methods are binary, the others unary.
class FloatElt a where
  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a

instance FloatElt Float where
  floatEltDiv = divVectorFloat
  floatEltPow = powVectorFloat
  floatEltLogbase = logbaseVectorFloat
  floatEltRecip = recipVectorFloat
  floatEltExp = expVectorFloat
  floatEltLog = logVectorFloat
  floatEltSqrt = sqrtVectorFloat
  floatEltSin = sinVectorFloat
  floatEltCos = cosVectorFloat
  floatEltTan = tanVectorFloat
  floatEltAsin = asinVectorFloat
  floatEltAcos = acosVectorFloat
  floatEltAtan = atanVectorFloat
  floatEltSinh = sinhVectorFloat
  floatEltCosh = coshVectorFloat
  floatEltTanh = tanhVectorFloat
  floatEltAsinh = asinhVectorFloat
  floatEltAcosh = acoshVectorFloat
  floatEltAtanh = atanhVectorFloat
  floatEltLog1p = log1pVectorFloat
  floatEltExpm1 = expm1VectorFloat
  floatEltLog1pexp = log1pexpVectorFloat
  floatEltLog1mexp = log1mexpVectorFloat

instance FloatElt Double where
  floatEltDiv = divVectorDouble
  floatEltPow = powVectorDouble
  floatEltLogbase = logbaseVectorDouble
  floatEltRecip = recipVectorDouble
  floatEltExp = expVectorDouble
  floatEltLog = logVectorDouble
  floatEltSqrt = sqrtVectorDouble
  floatEltSin = sinVectorDouble
  floatEltCos = cosVectorDouble
  floatEltTan = tanVectorDouble
  floatEltAsin = asinVectorDouble
  floatEltAcos = acosVectorDouble
  floatEltAtan = atanVectorDouble
  floatEltSinh = sinhVectorDouble
  floatEltCosh = coshVectorDouble
  floatEltTanh = tanhVectorDouble
  floatEltAsinh = asinhVectorDouble
  floatEltAcosh = acoshVectorDouble
  floatEltAtanh = atanhVectorDouble
  floatEltLog1p = log1pVectorDouble
  floatEltExpm1 = expm1VectorDouble
  floatEltLog1pexp = log1pexpVectorDouble
  floatEltLog1mexp = log1mexpVectorDouble

-- ---------------------------------------------------------------------------
-- Next file in the patch; original patch header (line counts refer to the
-- original patch, not to this listing):
--
-- diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
-- new file mode 100644
-- index 0000000..6fc7229
-- --- /dev/null
-- +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
-- @@ -0,0 +1,55 @@
-- ---------------------------------------------------------------------------
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Mixed.Internal.Arith.Foreign where

import Control.Monad
import Data.Int
import Data.Maybe
import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH

import Data.Array.Mixed.Internal.Arith.Lists


-- Binary kernels for every type in 'typesList'. For C name @t@ this imports
-- the C symbols oxarop_binary_t_{sv,vv,vs} as c_binary_t_{sv,vv,vs}. The
-- leading 'CInt' argument is the operation tag (callers pass 'aboEnum'),
-- the 'Int64' argument the element count.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
        base = "binary_" ++ atCName arithtype
        -- unsafe ccall import of oxarop_<base><suffix> as c_<base><suffix>
        importAs suffix qtype =
          ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ suffix) (mkName ("c_" ++ base ++ suffix)) <$> qtype
    sequence $ catMaybes
      [Just (importAs "_sv" [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
      ,Just (importAs "_vv" [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
      ,Just (importAs "_vs" [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
      ])

-- Floating-point binary kernels (oxarop_fbinary_*) for every type in
-- 'floatTypesList'.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
        base = "fbinary_" ++ atCName arithtype
        -- unsafe ccall import of oxarop_<base><suffix> as c_<base><suffix>
        importAs suffix qtype =
          ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ suffix) (mkName ("c_" ++ base ++ suffix)) <$> qtype
    sequence $ catMaybes
      [Just (importAs "_sv" [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
      ,Just (importAs "_vv" [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
      ,Just (importAs "_vs" [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
      ])

-- Unary kernels (oxarop_unary_*) for every type in 'typesList'.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
        base = "unary_" ++ atCName arithtype
    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])

-- Floating-point unary kernels (oxarop_funary_*) for every type in
-- 'floatTypesList'.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
        base = "funary_" ++ atCName arithtype
    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])

-- Reduction kernels (oxarop_reduce_*) for every type in 'typesList'. After
-- the operation tag, callers pass the rank, a pointer to the shape, a
-- pointer to the strides, the output pointer and the input pointer.
$(fmap concat . forM typesList $ \arithtype -> do
    let ttyp = conT (atType arithtype)
        base = "reduce_" ++ atCName arithtype
    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])

-- ---------------------------------------------------------------------------
-- Next file in the patch; original patch header (line counts refer to the
-- original patch, not to this listing):
--
-- diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
-- new file mode 100644
-- index 0000000..a284bc1
-- --- /dev/null
-- +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
-- @@ -0,0 +1,78 @@
-- ---------------------------------------------------------------------------
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Mixed.Internal.Arith.Lists where

import Data.Char
import Data.Int
import Language.Haskell.TH

import Data.Array.Mixed.Internal.Arith.Lists.TH


-- | An element type for which foreign kernels exist: the Haskell type name
-- together with the infix used in the corresponding C symbol names.
data ArithType = ArithType
  { atType :: Name  -- ''Int32
  , atCName :: String  -- "i32"
  }

-- | The floating-point element types.
floatTypesList :: [ArithType]
floatTypesList =
  [ArithType ''Float "float"
  ,ArithType ''Double "double"
  ]

-- | All element types: the fixed-width integers followed by
-- 'floatTypesList'.
typesList :: [ArithType]
typesList =
  [ArithType ''Int32 "i32"
  ,ArithType ''Int64 "i64"
  ]
  ++ floatTypesList

-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
$(genArithDataType Binop "ArithBOp")

-- aboName: constructor name without its 3-character prefix, lower-cased.
$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
-- aboEnum: numeric tag of the operation as listed in the C header.
$(genArithEnumFun Binop ''ArithBOp "aboEnum")

-- aboNumOp: the 'Name' of the Haskell operator given as the third field of
-- the operation's entry in the C header.
$(do clauses <- readArithLists Binop
                  (\name _num hsop ->
                     return (Clause [ConP (mkName name) [] []]
                                    (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
                                    []))
                  return
     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
              ,return $ FunD (mkName "aboNumOp") clauses])


-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
$(genArithDataType FBinop "ArithFBOp")

$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")

-- afboNumOp: as aboNumOp, for the floating-point binary operations.
$(do clauses <- readArithLists FBinop
                  (\name _num hsop ->
                     return (Clause [ConP (mkName name) [] []]
                                    (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
                                    []))
                  return
     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
              ,return $ FunD (mkName "afboNumOp") clauses])


-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
$(genArithDataType Unop "ArithUOp")

$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
$(genArithEnumFun Unop ''ArithUOp "auoEnum")


-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
$(genArithDataType FUnop "ArithFUOp")

$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")


-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
$(genArithDataType Redop "ArithRedOp")

$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
$(genArithEnumFun Redop ''ArithRedOp "aroEnum")

-- ---------------------------------------------------------------------------
-- Next file in the patch; original patch header (line counts refer to the
-- original patch, not to this listing):
--
-- diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
-- new file mode 100644
-- index 0000000..8b7d05f
-- --- /dev/null
-- +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
-- @@ -0,0 +1,82 @@
-- ---------------------------------------------------------------------------
{-# LANGUAGE TemplateHaskellQuotes #-}
module Data.Array.Mixed.Internal.Arith.Lists.TH where

import Control.Monad
import Control.Monad.IO.Class
import Data.Maybe
import Foreign.C.Types
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Text.Read


-- | The kinds of operation lists found in @cbits/arith_lists.h@.
data OpKind = Binop | FBinop | Unop | FUnop | Redop
  deriving (Show, Eq)

-- | Read @cbits/arith_lists.h@ (registering it as a dependency of the
-- module being compiled), run the given action on every entry of the
-- requested kind, and pass the collected results to the continuation.
--
-- Every line that does not consist solely of spaces must have the form
-- @LIST_\<KIND\>(name, number, aux)@, where @\<KIND\>@ is one of @BINOP@,
-- @FBINOP@, @UNOP@, @FUNOP@, @REDOP@; any other line is reported via
-- 'error'. The action receives the name, the number and the auxiliary
-- field of the entry.
readArithLists :: OpKind
               -> (String -> Int -> String -> Q a)
               -> ([a] -> Q r)
               -> Q r
readArithLists targetkind fop fcombine = do
  addDependentFile "cbits/arith_lists.h"
  contents <- liftIO (readFile "cbits/arith_lists.h")

  results <- forM (lines contents) $ \line ->
    if all (== ' ') line
      then return Nothing
      else case parseLine line of
             (kind, name, num, aux)
               | kind == targetkind -> Just <$> fop name num aux
               | otherwise -> return Nothing

  fcombine (catMaybes results)
  where
    -- Split one header line into (kind, name, number, auxiliary field).
    parseLine s0
      | ("LIST_", s1) <- splitAt 5 s0
      , (kindstr, '(' : s2) <- break (== '(') s1
      , (name, ',' : s3) <- parseField s2
      , (numstr, ',' : s4) <- parseField s3
      , (aux, ')' : _) <- parseField s4
      , Just kind <- parseKind kindstr
      , Just num <- readMaybe numstr
      = (kind, name, num, aux)
      | otherwise
      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0

    -- Skip leading spaces, then take everything up to the next ',' or ')'.
    parseField s = break (`elem` ",)") (dropWhile (== ' ') s)

    parseKind kindstr =
      lookup kindstr
        [("BINOP", Binop)
        ,("FBINOP", FBinop)
        ,("UNOP", Unop)
        ,("FUNOP", FUnop)
        ,("REDOP", Redop)
        ]

-- | Generate a data type with the given name that has one nullary
-- constructor per entry of the given kind, deriving 'Show', 'Enum' and
-- 'Bounded'.
genArithDataType :: OpKind -> String -> Q [Dec]
genArithDataType kind dtname = do
  cons <- readArithLists kind
            (\name _num _ -> return $ NormalC (mkName name) [])
            return
  let derivs = [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]
  return [DataD [] (mkName dtname) [] Nothing cons derivs]

-- | Generate a function (with signature) from the given data type to
-- 'String' that maps each constructor to its transformed name.
genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
genArithNameFun kind dtname funname nametrans = do
  clauses <- readArithLists kind
               (\name _num _ ->
                  return (Clause [ConP (mkName name) [] []]
                                 (NormalB (LitE (StringL (nametrans name))))
                                 []))
               return
  let fname = mkName funname
  return [SigD fname (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
         ,FunD fname clauses]

-- | Generate a function (with signature) from the given data type to
-- 'CInt' that maps each constructor to the number listed in the header.
genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
genArithEnumFun kind dtname funname = do
  clauses <- readArithLists kind
               (\name num _ ->
                  return (Clause [ConP (mkName name) [] []]
                                 (NormalB (LitE (IntegerL (fromIntegral num))))
                                 []))
               return
  let fname = mkName funname
  return [SigD fname (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
         ,FunD fname clauses]