aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
commita65306ba5d80891b20ac86fa3a3242f9497751e6 (patch)
tree834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed/Internal
parentd8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff)
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs435
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs55
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists.hs78
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs82
4 files changed, 650 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
new file mode 100644
index 0000000..cf6820b
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -0,0 +1,435 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Internal.Arith where
+
+import Control.Monad (forM, guard)
+import qualified Data.Array.Internal as OI
+import qualified Data.Array.Internal.RankedG as RG
+import qualified Data.Array.Internal.RankedS as RS
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO.Unsafe
+
+import Data.Array.Mixed.Internal.Arith.Foreign
+import Data.Array.Mixed.Internal.Arith.Lists
+
+
+liftVEltwise1 :: Storable a
+ => SNat n
+ -> (VS.Vector a -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just prefixSz <- stridesDense sh strides =
+ let vec' = f (VS.slice offset prefixSz vec)
+ in RS.A (RG.A sh (OI.T strides 0 vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))
+
+liftVEltwise2 :: Storable a
+ => SNat n
+ -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
+ arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+ arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+ | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
+ | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
+ (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+ in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+ (Just 1, Just n) -> -- scalar * dense
+ RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
+ (Just n, Just 1) -> -- dense * scalar
+ RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
+ (_, _) -> -- fallback case
+ RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+
+-- | Given the shape vector and the stride vector, return whether this vector
+-- of strides uses a dense prefix of its backing array. If so, the number of
+-- elements in this prefix is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> [Int] -> Maybe Int
+stridesDense sh _ | any (<= 0) sh = Just 0
+stridesDense sh str =
+ -- sort dimensions on their stride, ascending, dropping any zero strides
+ case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+ [] -> Just 1
+ (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
+ where
+ -- Given size of currently densely covered region at beginning of the
+ -- array, the remaining shape vector and the corresponding remaining stride
+ -- vector, return whether this all together covers a dense prefix of the
+ -- array. If it does, return the number of elements in this prefix.
+ checkCover :: Int -> [Int] -> [Int] -> Maybe Int
+ checkCover block [] [] = Just block
+ checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
+ checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+
+{-# NOINLINE vectorOp1 #-}
+vectorOp1 :: forall a b. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr b -> IO ())
+ -> VS.Vector a -> VS.Vector a
+vectorOp1 ptrconv f v = unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length v)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith v $ \pv ->
+ f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
+ VS.unsafeFreeze outv
+
+-- | If two vectors are given, assumes that they have the same length.
+{-# NOINLINE vectorOp2 #-}
+vectorOp2 :: forall a b. Storable a
+ => (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
+ -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
+ -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
+ -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
+vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
+ (Left x) (Left y) -> VS.singleton (fss x y)
+
+ (Left x) (Right vy) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vy)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vy $ \pvy ->
+ fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Left y) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Right vy)
+ | VS.length vx == VS.length vy ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ VS.unsafeWith vy $ \pvy ->
+ fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
+ VS.unsafeFreeze outv
+ | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+
+-- TODO: test all the weird cases of this function
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> RS.Array (n + 1) a -> RS.Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+ | null sh = error "unreachable"
+ | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0])
+ | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
+ -- now the input array is nonempty
+ | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
+ | last strides == 0 =
+ liftVEltwise1 sn
+ (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
+ (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
+ -- now there is useful work along the inner dimension
+ | otherwise =
+ let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
+ (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
+ ndimsF = length shF
+ in unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product (init shF))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+ fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+ RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+
+flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
+ -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
+flipOp f n out v s = f n out s v
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss = varE (aboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss = varE (afboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
+ -> (SNat n -> RS.Array n i -> RS.Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
+ -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+ | otherwise = error "Unsupported Int width"
+
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
+
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
+
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
+
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
+
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+class FloatElt a where
+ floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltPow = powVectorFloat
+ floatEltLogbase = logbaseVectorFloat
+ floatEltRecip = recipVectorFloat
+ floatEltExp = expVectorFloat
+ floatEltLog = logVectorFloat
+ floatEltSqrt = sqrtVectorFloat
+ floatEltSin = sinVectorFloat
+ floatEltCos = cosVectorFloat
+ floatEltTan = tanVectorFloat
+ floatEltAsin = asinVectorFloat
+ floatEltAcos = acosVectorFloat
+ floatEltAtan = atanVectorFloat
+ floatEltSinh = sinhVectorFloat
+ floatEltCosh = coshVectorFloat
+ floatEltTanh = tanhVectorFloat
+ floatEltAsinh = asinhVectorFloat
+ floatEltAcosh = acoshVectorFloat
+ floatEltAtanh = atanhVectorFloat
+ floatEltLog1p = log1pVectorFloat
+ floatEltExpm1 = expm1VectorFloat
+ floatEltLog1pexp = log1pexpVectorFloat
+ floatEltLog1mexp = log1mexpVectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltPow = powVectorDouble
+ floatEltLogbase = logbaseVectorDouble
+ floatEltRecip = recipVectorDouble
+ floatEltExp = expVectorDouble
+ floatEltLog = logVectorDouble
+ floatEltSqrt = sqrtVectorDouble
+ floatEltSin = sinVectorDouble
+ floatEltCos = cosVectorDouble
+ floatEltTan = tanVectorDouble
+ floatEltAsin = asinVectorDouble
+ floatEltAcos = acosVectorDouble
+ floatEltAtan = atanVectorDouble
+ floatEltSinh = sinhVectorDouble
+ floatEltCosh = coshVectorDouble
+ floatEltTanh = tanhVectorDouble
+ floatEltAsinh = asinhVectorDouble
+ floatEltAcosh = acoshVectorDouble
+ floatEltAtanh = atanhVectorDouble
+ floatEltLog1p = log1pVectorDouble
+ floatEltExpm1 = expm1VectorDouble
+ floatEltLog1pexp = log1pexpVectorDouble
+ floatEltLog1mexp = log1mexpVectorDouble
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
new file mode 100644
index 0000000..6fc7229
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Mixed.Internal.Arith.Foreign where
+
+import Control.Monad
+import Data.Int
+import Data.Maybe
+import Foreign.C.Types
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Mixed.Internal.Arith.Lists
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "fbinary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "funary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
new file mode 100644
index 0000000..a284bc1
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Mixed.Internal.Arith.Lists where
+
+import Data.Char
+import Data.Int
+import Language.Haskell.TH
+
+import Data.Array.Mixed.Internal.Arith.Lists.TH
+
+
+data ArithType = ArithType
+ { atType :: Name -- ''Int32
+ , atCName :: String -- "i32"
+ }
+
+floatTypesList :: [ArithType]
+floatTypesList =
+ [ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
+typesList :: [ArithType]
+typesList =
+ [ArithType ''Int32 "i32"
+ ,ArithType ''Int64 "i64"
+ ]
+ ++ floatTypesList
+
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
+
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
+
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
+
+
+-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
+
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
+-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
new file mode 100644
index 0000000..8b7d05f
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
@@ -0,0 +1,82 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Mixed.Internal.Arith.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Language.Haskell.TH.Syntax
+import Text.Read
+
+
+data OpKind = Binop | FBinop | Unop | FUnop | Redop
+ deriving (Show, Eq)
+
+readArithLists :: OpKind
+ -> (String -> Int -> String -> Q a)
+ -> ([a] -> Q r)
+ -> Q r
+readArithLists targetkind fop fcombine = do
+ addDependentFile "cbits/arith_lists.h"
+ lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+ mvals <- forM lns $ \line -> do
+ if null (dropWhile (== ' ') line)
+ then return Nothing
+ else do let (kind, name, num, aux) = parseLine line
+ if kind == targetkind
+ then Just <$> fop name num aux
+ else return Nothing
+
+ fcombine (catMaybes mvals)
+ where
+ parseLine s0
+ | ("LIST_", s1) <- splitAt 5 s0
+ , (kindstr, '(' : s2) <- break (== '(') s1
+ , (f1, ',' : s3) <- parseField s2
+ , (f2, ',' : s4) <- parseField s3
+ , (f3, ')' : _) <- parseField s4
+ , Just kind <- parseKind kindstr
+ , let name = f1
+ , Just num <- readMaybe f2
+ , let aux = f3
+ = (kind, name, num, aux)
+ | otherwise
+ = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+ parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+ parseKind "BINOP" = Just Binop
+ parseKind "FBINOP" = Just FBinop
+ parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
+ parseKind "REDOP" = Just Redop
+ parseKind _ = Nothing
+
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname = do
+ cons <- readArithLists kind
+ (\name _num _ -> return $ NormalC (mkName name) [])
+ return
+ return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
+
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+ clauses <- readArithLists kind
+ (\name _num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (StringL (nametrans name))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
+ ,FunD (mkName funname) clauses]
+
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+ clauses <- readArithLists kind
+ (\name num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (IntegerL (fromIntegral num))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
+ ,FunD (mkName funname) clauses]