diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
| commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
| tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed/Internal | |
| parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) | |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 435 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 55 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 78 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 82 | 
4 files changed, 650 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
new file mode 100644
index 0000000..cf6820b
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -0,0 +1,435 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Internal.Arith where
+
+import Control.Monad (forM, guard)
+import qualified Data.Array.Internal as OI
+import qualified Data.Array.Internal.RankedG as RG
+import qualified Data.Array.Internal.RankedS as RS
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO.Unsafe
+
+import Data.Array.Mixed.Internal.Arith.Foreign
+import Data.Array.Mixed.Internal.Arith.Lists
+
+
+liftVEltwise1 :: Storable a
+              => SNat n
+              -> (VS.Vector a -> VS.Vector a)  -- ^ kernel; on the dense path its output is reused with the input's strides, so it must keep length and element positions
+              -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+  | Just prefixSz <- stridesDense sh strides =
+      let vec' = f (VS.slice offset prefixSz vec)
+      in RS.A (RG.A sh (OI.T strides 0 vec'))
+  | otherwise = RS.fromVector sh (f (RS.toVector arr))
+
+liftVEltwise2 :: Storable a
+              => SNat n
+              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)  -- ^ kernel; 'Left' carries a single scalar, 'Right' a vector of elements
+              -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
+    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs
+  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
+      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
+        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+      (Just 1, Just n) ->  -- scalar * dense: the output is in arr2's memory order, so index it with strides2 (offset 0); this keeps transposed/replicated layouts correct
+        RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))))
+      (Just n, Just 1) ->  -- dense * scalar: symmetric; the output is in arr1's memory order, so index it with strides1 (offset 0)
+        RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))))
+      (_, _) ->  -- fallback case: both operands are copied out with RS.toVector and the result is rebuilt with RS.fromVector
+        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+
+-- | Given the shape vector and the stride vector, return whether this vector
+-- of strides uses a dense prefix of its backing array. If so, the number of
+-- elements in this prefix is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> [Int] -> Maybe Int
+stridesDense sh _ | any (<= 0) sh = Just 0
+stridesDense sh str =
+  -- sort dimensions on their stride, ascending, dropping any zero strides
+  case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+    [] -> Just 1
+    (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
+    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
+  where
+    -- Given size of currently densely covered region at beginning of the
+    -- array, the remaining shape vector and the corresponding remaining stride
+    -- vector, return whether this all together covers a dense prefix of the
+    -- array. If it does, return the number of elements in this prefix.
+    checkCover :: Int -> [Int] -> [Int] -> Maybe Int +    checkCover block [] [] = Just block +    checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' +    checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +{-# NOINLINE vectorOp1 #-} +vectorOp1 :: forall a b. Storable a +          => (Ptr a -> Ptr b) +          -> (Int64 -> Ptr b -> Ptr b -> IO ()) +          -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do +  outv <- VSM.unsafeNew (VS.length v) +  VSM.unsafeWith outv $ \poutv -> +    VS.unsafeWith v $ \pv -> +      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) +  VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-} +vectorOp2 :: forall a b. Storable a +          => (a -> b) +          -> (Ptr a -> Ptr b) +          -> (a -> a -> a) +          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv +          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs +          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv +          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases +  (Left x) (Left y) -> VS.singleton (fss x y) + +  (Left x) (Right vy) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vy) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vy $ \pvy -> +          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) +      VS.unsafeFreeze outv + +  (Right vx) (Left y) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vx) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vx $ \pvx -> +          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) +      VS.unsafeFreeze outv + +  (Right vx) (Right vy) +    | VS.length vx == VS.length vy -> +        unsafePerformIO $ do +       
   outv <- VSM.unsafeNew (VS.length vx) +          VSM.unsafeWith outv $ \poutv -> +            VS.unsafeWith vx $ \pvx -> +              VS.unsafeWith vy $ \pvy -> +                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) +          VS.unsafeFreeze outv +    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) +                 => SNat n +                 -> (a -> b) +                 -> (Ptr a -> Ptr b) +                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                 -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = error "unreachable" +  | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) +  | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) +  -- now the input array is nonempty +  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) +  | last strides == 0 = +      liftVEltwise1 sn +        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) +        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) +  -- now there is useful work along the inner dimension +  | otherwise = +      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those +          (shF, stridesF) = unzip $ filter ((/= 0) . 
snd) (zip sh strides) +          ndimsF = length shF +      in unsafePerformIO $ do +           outv <- VSM.unsafeNew (product (init shF)) +           VSM.unsafeWith outv $ \poutv -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> +                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> +                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) +           RS.fromVector (init sh) <$> VS.unsafeFreeze outv + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) +       ->  Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_binary_" ++ atCName arithtype +          c_ss = varE (aboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . 
forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_fbinary_" ++ atCName arithtype +          c_ss = varE (afboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. 
SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) +          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) +              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) +              -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. 
(FiniteBits i, Storable i, Integral i) +              => (i -> i -> i)  -- ss +                 -- int32 +              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv +                 -- int64 +              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv +              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. 
(FiniteBits i, Storable i, Integral i) +                => -- int32 +                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                   -- int64 +                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 +  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 +  | otherwise = error "Unsupported Int width" + +class NumElt a where +  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a +  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + +instance NumElt Int32 where +  numEltAdd = addVectorInt32 +  numEltSub = subVectorInt32 +  numEltMul = mulVectorInt32 +  numEltNeg = negVectorInt32 +  numEltAbs = absVectorInt32 +  numEltSignum = signumVectorInt32 +  numEltSum1Inner = sum1VectorInt32 +  numEltProduct1Inner = product1VectorInt32 + +instance NumElt Int64 where +  numEltAdd = addVectorInt64 +  numEltSub = subVectorInt64 +  numEltMul = mulVectorInt64 +  numEltNeg = negVectorInt64 +  numEltAbs = absVectorInt64 +  numEltSignum = signumVectorInt64 +  numEltSum1Inner = sum1VectorInt64 +  
numEltProduct1Inner = product1VectorInt64 + +instance NumElt Float where +  numEltAdd = addVectorFloat +  numEltSub = subVectorFloat +  numEltMul = mulVectorFloat +  numEltNeg = negVectorFloat +  numEltAbs = absVectorFloat +  numEltSignum = signumVectorFloat +  numEltSum1Inner = sum1VectorFloat +  numEltProduct1Inner = product1VectorFloat + +instance NumElt Double where +  numEltAdd = addVectorDouble +  numEltSub = subVectorDouble +  numEltMul = mulVectorDouble +  numEltNeg = negVectorDouble +  numEltAbs = absVectorDouble +  numEltSignum = signumVectorDouble +  numEltSum1Inner = sum1VectorDouble +  numEltProduct1Inner = product1VectorDouble + +instance NumElt Int where +  numEltAdd = intWidBranch2 @Int (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @Int (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @Int (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @Int +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum 
RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @Int +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +instance NumElt CInt where +  numEltAdd = intWidBranch2 @CInt (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @CInt (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @CInt (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @CInt +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @CInt +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv 
(aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where +  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where +  floatEltDiv = divVectorFloat +  floatEltPow = powVectorFloat +  floatEltLogbase = logbaseVectorFloat +  floatEltRecip = recipVectorFloat +  floatEltExp = expVectorFloat +  floatEltLog = logVectorFloat +  floatEltSqrt = sqrtVectorFloat +  floatEltSin = sinVectorFloat +  floatEltCos = cosVectorFloat +  floatEltTan = tanVectorFloat +  floatEltAsin = asinVectorFloat +  floatEltAcos = acosVectorFloat +  floatEltAtan = atanVectorFloat +  floatEltSinh = sinhVectorFloat +  floatEltCosh = 
coshVectorFloat +  floatEltTanh = tanhVectorFloat +  floatEltAsinh = asinhVectorFloat +  floatEltAcosh = acoshVectorFloat +  floatEltAtanh = atanhVectorFloat +  floatEltLog1p = log1pVectorFloat +  floatEltExpm1 = expm1VectorFloat +  floatEltLog1pexp = log1pexpVectorFloat +  floatEltLog1mexp = log1mexpVectorFloat + +instance FloatElt Double where +  floatEltDiv = divVectorDouble +  floatEltPow = powVectorDouble +  floatEltLogbase = logbaseVectorDouble +  floatEltRecip = recipVectorDouble +  floatEltExp = expVectorDouble +  floatEltLog = logVectorDouble +  floatEltSqrt = sqrtVectorDouble +  floatEltSin = sinVectorDouble +  floatEltCos = cosVectorDouble +  floatEltTan = tanVectorDouble +  floatEltAsin = asinVectorDouble +  floatEltAcos = acosVectorDouble +  floatEltAtan = atanVectorDouble +  floatEltSinh = sinhVectorDouble +  floatEltCosh = coshVectorDouble +  floatEltTanh = tanhVectorDouble +  floatEltAsinh = asinhVectorDouble +  floatEltAcosh = acoshVectorDouble +  floatEltAtanh = atanhVectorDouble +  floatEltLog1p = log1pVectorDouble +  floatEltExpm1 = expm1VectorDouble +  floatEltLog1pexp = log1pexpVectorDouble +  floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..6fc7229 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "binary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . 
ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "fbinary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "unary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "funary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . 
forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "reduce_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs new file mode 100644 index 0000000..a284bc1 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists.TH + + +data ArithType = ArithType +  { atType :: Name  -- ''Int32 +  , atCName :: String  -- "i32" +  } + +floatTypesList :: [ArithType] +floatTypesList = +  [ArithType ''Float "float" +  ,ArithType ''Double "double" +  ] + +typesList :: [ArithType] +typesList = +  [ArithType ''Int32 "i32" +  ,ArithType ''Int64 "i64" +  ] +  ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] +              ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . 
drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] +              ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . 
drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..8b7d05f --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Mixed.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | FBinop | Unop | FUnop | Redop +  deriving (Show, Eq) + +readArithLists :: OpKind +               -> (String -> Int -> String -> Q a) +               -> ([a] -> Q r) +               -> Q r +readArithLists targetkind fop fcombine = do +  addDependentFile "cbits/arith_lists.h" +  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + +  mvals <- forM lns $ \line -> do +    if null (dropWhile (== ' ') line) +      then return Nothing +      else do let (kind, name, num, aux) = parseLine line +              if kind == targetkind +                then Just <$> fop name num aux +                else return Nothing + +  fcombine (catMaybes mvals) +  where +    parseLine s0 +      | ("LIST_", s1) <- splitAt 5 s0 +      , (kindstr, '(' : s2) <- break (== '(') s1 +      , (f1, ',' : s3) <- parseField s2 +      , (f2, ',' : s4) <- parseField s3 +      , (f3, ')' : _) <- parseField s4 +      , Just kind <- parseKind kindstr +      , let name = f1 +      , Just num <- readMaybe f2 +      , let aux = f3 +      = (kind, name, num, aux) +      | otherwise +      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + +    parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + +    parseKind "BINOP" = Just Binop +    parseKind "FBINOP" = Just FBinop +    parseKind "UNOP" = Just Unop +    parseKind "FUNOP" = Just FUnop +    parseKind "REDOP" = Just Redop +    
parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do +  cons <- readArithLists kind +            (\name _num _ -> return $ NormalC (mkName name) []) +            return +  return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do +  clauses <- readArithLists kind +               (\name _num _ -> return (Clause [ConP (mkName name) [] []] +                                               (NormalB (LitE (StringL (nametrans name)))) +                                               [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) +         ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do +  clauses <- readArithLists kind +               (\name num _ -> return (Clause [ConP (mkName name) [] []] +                                              (NormalB (LitE (IntegerL (fromIntegral num)))) +                                              [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) +         ,FunD (mkName funname) clauses] | 
