aboutsummaryrefslogtreecommitdiff
path: root/ops/Data/Array/Strided/Arith/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'ops/Data/Array/Strided/Arith/Internal.hs')
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs866
1 files changed, 866 insertions, 0 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
new file mode 100644
index 0000000..fe0fc4b
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -0,0 +1,866 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Strided.Arith.Internal where
+
+import Control.Monad
+import Data.Bifunctor (second)
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable
+import qualified GHC.TypeNats as TypeNats
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO (hFlush, stdout)
+import System.IO.Unsafe
+
+import Data.Array.Strided.Array
+import Data.Array.Strided.Arith.Internal.Lists
+import Data.Array.Strided.Arith.Internal.Foreign
+
+
+-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+
+
+-- TODO: move this to a utilities module
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
+data Dict c where
+ Dict :: c => Dict c
+
+debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
+debugShow (Array sh strides offset vec) =
+ "Array @" ++ (show (natVal (Proxy @n))) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">"
+
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise1 :: (Storable a, Storable b)
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a -> Array n a
+liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ if blockSz == 0
+ then Array sh (map (const 0) strides) 0 VS.empty
+ else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec)
+ in Array sh strides (offset - blockOff) resvec
+ | otherwise = wrapUnary sn ptrconv cf_strided arr
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise2 :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv
+ -> Array n a -> Array n a -> Array n a
+liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
+ arr1@(Array sh1 strides1 offset1 vec1)
+ arr2@(Array sh2 strides2 offset2 vec2)
+ | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty
+ | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
+ (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
+ in Array sh1 strides1 0 vec'
+
+ (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
+ let arr2' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec2)
+ resvec = arrValues $ wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+ in Array sh1 strides2 (offset2 - blockOff) resvec
+
+ (Just (_, 1), Nothing) -> -- scalar * array
+ wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
+ (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
+ let arr1' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec1)
+ resvec = arrValues $ wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+ in Array sh1 strides1 (offset1 - blockOff) resvec
+
+ (Nothing, Just (_, 1)) -> -- array * scalar
+ wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
+ (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
+ | strides1 == strides2
+ -> -- dense * dense but the strides match
+ if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2
+ then error $ "Data.Array.Strided.Ops.Internal(liftOpEltwise2): Internal error: cannot happen " ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2))
+ else
+ let arr1' = arrayFromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+ arr2' = arrayFromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+ resvec = arrValues $ wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
+ in Array sh1 strides1 (offset1 - blockOff1) resvec
+
+ (_, _) -> -- fallback case
+ wrapBinaryVV sn ptrconv f_vv arr1 arr2
+
+-- | Given shape vector, offset and stride vector, check whether this virtual
+-- vector uses a dense subarray of its backing array. If so, the first index
+-- and the number of elements in this subarray is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
+stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
+stridesDense sh offsetNeg stridesNeg =
+ -- First reverse all dimensions with negative stride, so that the first used
+ -- value is at 'offset' and the rest is >= offset.
+ let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
+ in -- sort dimensions on their stride, ascending, dropping any zero strides
+ case filter ((/= 0) . fst) (sort (zip strides sh)) of
+ [] -> Just (offset, 1)
+ (1, n) : pairs -> (offset,) <$> checkCover n pairs
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
+ where
+ -- Given size of currently densely covered region at beginning of the
+ -- array and the remaining (stride, size) pairs with all strides >=1,
+ -- return whether this all together covers a dense prefix of the array. If
+ -- it does, return the number of elements in this prefix.
+ checkCover :: Int -> [(Int, Int)] -> Maybe Int
+ checkCover block [] = Just block
+ checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs
+
+ -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
+ flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
+ flipReverseds [] off [] = (off, [])
+ flipReverseds (n : sh') off (s : str')
+ | s >= 0 = second (s :) (flipReverseds sh' off str')
+ | otherwise =
+ let off' = off + (n - 1) * s
+ in second ((-s) :) (flipReverseds sh' off' str')
+ flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
+
+data Unreplicated a =
+ forall n'. KnownNat n' =>
+ -- | Let the original array, with replicated dimensions, be called A.
+ Unreplicated -- | An array with all strides /= 0. Call this array U. It has
+ -- the same shape as A, except with all the replicated (stride
+ -- == 0) dimensions removed. The shape of U is the
+ -- "unreplicated shape".
+ (Array n' a)
+ -- | Product of sizes of the unreplicated dimensions
+ Int
+ -- | Given the stride vector of an array with the unreplicated
+ -- shape, this function reinserts zeros so that it may be
+ -- combined with the original shape of A.
+ ([Int] -> [Int])
+
+-- | Removes all replicated dimensions (i.e. those with stride == 0) from the array.
+unreplicateStrides :: Array n a -> Unreplicated a
+unreplicateStrides (Array sh strides offset vec) =
+ let replDims = map (== 0) strides
+ (shF, stridesF) = unzip [(n, s) | (n, s) <- zip sh strides, s /= 0]
+
+ reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+ reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+ reinsertZeros [] [] = []
+ reinsertZeros (False : _) [] = error $ "unreplicateStrides: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error $ "unreplicateStrides: Internal error: reply strides too long"
+
+ unrepSize = product [n | (n, True) <- zip sh replDims]
+
+ in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
+
+simplifyArray :: Array n a
+ -> (forall n'. KnownNat n'
+ => Array n' a -- U
+ -- Product of sizes of the unreplicated dimensions
+ -> Int
+ -- Convert index in U back to index into original
+ -- array. Replicated dimensions get 0.
+ -> ([Int] -> [Int])
+ -- Given a new array of the same shape as U, convert
+ -- it back to the original shape and iteration order.
+ -> (Array n' a -> Array n a)
+ -- Do the same except without the INNER dimension.
+ -- This throws an error if the inner dimension had
+ -- stride 0.
+ -> (Array (n' - 1) a -> Array (n - 1) a)
+ -> r)
+ -> r
+simplifyArray array k
+ | let revDims = map (<0) (arrStrides array)
+ , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
+ = k array'
+ unrepSize
+ (\idx -> rereplicate (zipWith3 (\b n i -> if b then n - 1 - i else i)
+ revDims (arrShape array') idx))
+ (\(Array sh' strides' offset' vec') ->
+ if sh' == arrShape array'
+ then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
+ else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
+ (\(Array sh' strides' offset' vec') ->
+ if | sh' /= init (arrShape array') ->
+ error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
+ | last (arrStrides array) == 0 ->
+ error $ "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
+ | otherwise ->
+ arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
+
+{-# NOINLINE wrapUnary #-}
+wrapUnary :: forall a b n. Storable a
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a
+ -> Array n a
+wrapUnary _ ptrconv cf_strided array =
+ simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+ let ndims' = length sh
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+ VS.unsafeWith vec $ \pv ->
+ let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral ndims') (ptrconv poutv) psh pstrides pv'
+ restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: forall a b n. Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+ -> a -> Array n a
+ -> Array n a
+wrapBinarySV SNat valconv ptrconv cf_strided x array =
+ simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+ let ndims' = length sh
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+ VS.unsafeWith vec $ \pv ->
+ let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral ndims') psh (ptrconv poutv) (valconv x) pstrides pv'
+ restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+wrapBinaryVS :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+ -> Array n a -> a
+ -> Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+ wrapBinarySV sn valconv ptrconv
+ (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
+
+-- | The two shapes must be equal and non-empty. This is checked.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: forall a b n. Storable a
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a -> Array n a
+ -> Array n a
+-- TODO: do unreversing and unreplication on the input arrays (but
+-- simultaneously: can only unreplicate if _both_ are replicated on that
+-- dimension)
+wrapBinaryVV sn@SNat ptrconv cf_strided
+ (Array sh strides1 offset1 vec1)
+ (Array sh2 strides2 offset2 vec2)
+ | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
+ | any (<= 0) sh = error $ "wrapBinaryVV: empty shape: " ++ show sh
+ | otherwise = unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith vec1 $ \pv1 ->
+ VS.unsafeWith vec2 $ \pv2 ->
+ let pv1' = pv1 `plusPtr` (offset1 * sizeOf (undefined :: a))
+ pv2' = pv2 `plusPtr` (offset2 * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 pv1' pstrides2 pv2'
+ arrayFromVector sh <$> VS.unsafeFreeze outv
+
+-- TODO: test handling of negative strides
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> Array (n + 1) a -> Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec)
+ | null sh = error "unreachable"
+ | last sh <= 0 = arrayFromConstant (init sh) 0
+ | any (<= 0) (init sh) = Array (init sh) (0 <$ init strides) 0 VS.empty
+ -- now the input array is nonempty
+ | last sh == 1 = Array (init sh) (init strides) offset vec
+ | last strides == 0 =
+ wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a (last sh))
+ (Array (init sh) (init strides) offset vec)
+ -- now there is useful work along the inner dimension
+ -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point.
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore -> unsafePerformIO $ do
+ let ndims' = length sh'
+ outv <- VSM.unsafeNew (product (init sh'))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
+ TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+ LTI -> pure Dict
+ EQI -> pure Dict
+ _ -> error "impossible" -- because `last strides /= 0`
+ case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+ Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv
+ Nothing -> error "impossible"
+
+-- TODO: test handling of negative strides
+-- | Reduce full array
+{-# NOINLINE vectorRedFullOp #-}
+vectorRedFullOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> Int -> a)
+ -> (b -> a)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -> Array n a -> a
+vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec)
+ | null sh = vec VS.! offset -- 0D array has one element
+ | any (<= 0) sh = 0
+ -- now the input array is nonempty
+ | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset
+ -- now there is at least one non-replicated dimension
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do
+ let ndims' = length sh'
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in (`scaleval` unrepSize) . valbackconv
+ <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv')
+
+-- TODO: test this function
+-- | Find extremum (minindex ("argmin") or maxindex) in full array
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> Array n a -> [Int] -- result length: n
+vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
+ | null sh = []
+ | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
+ -- now the input array is nonempty
+ | all (== 0) strides = 0 <$ sh
+ -- now there is at least one non-replicated dimension
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec') _ upindex _ _ -> unsafePerformIO $ do
+ let ndims' = length sh'
+ outvR <- VSM.unsafeNew (length sh')
+ VSM.unsafeWith outvR $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in fextrem poutv (fromIntegral ndims') psh pstrides (ptrconv pv')
+ upindex . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outvR
+
+{-# NOINLINE vectorDotprodInnerOp #-}
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (SNat n -> Array n a -> Array n a -> Array n a) -- ^ elementwise multiplication
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+ arr1@(Array sh1 strides1 offset1 vec1)
+ arr2@(Array sh2 strides2 offset2 vec2)
+ | null sh1 || null sh2 = error "unreachable"
+ | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | last sh1 <= 0 = arrayFromConstant (init sh1) 0
+ | any (<= 0) (init sh1) = Array (init sh1) (0 <$ init strides1) 0 VS.empty
+ -- now the input arrays are nonempty
+ | last sh1 == 1 =
+ fmul sn (Array (init sh1) (init strides1) offset1 vec1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ | last strides1 == 0 =
+ fmul sn
+ (Array (init sh1) (init strides1) offset1 vec1)
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+ | last strides2 == 0 =
+ fmul sn
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ -- now there is useful dotprod work along the inner dimension
+ | otherwise = unsafePerformIO $ do
+ let inrank = fromSNat' sn + 1
+ outv <- VSM.unsafeNew (product (init sh1))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+ pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
+ pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
+ arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+
+mulWithInt :: Num a => a -> Int -> a
+mulWithInt a i = a * fromIntegral i
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss_str = varE (aboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM intTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_ibinary_" ++ atCName arithtype
+ c_ss_str = varE (aiboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss_str = varE (afboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let scaleVar = case arithop of
+ RO_SUM -> varE 'mulWithInt
+ RO_PRODUCT -> varE '(^)
+ let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
+ namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
+ c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ sequence [SigD name1 <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
+ return $ FunD name1 [Clause [] (NormalB body) []]
+ ,SigD namefull <$>
+ [t| forall n. SNat n -> Array n $ttyp -> $ttyp |]
+ ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
+ return $ FunD namefull [Clause [] (NormalB body) []]
+ ])
+
+$(fmap concat . forM typesList $ \arithtype ->
+ fmap concat . forM ["min", "max"] $ \fname -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+ sequence [SigD name <$>
+ [t| forall n. Array n $ttyp -> [Int] |]
+ ,do body <- [| vectorExtremumOp id $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+ mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
+foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
+
+statisticsEnable :: Bool -> IO ()
+statisticsEnable b = c_stats_enable (if b then 1 else 0)
+
+-- | Consumes the log: one particular event will only ever be printed once,
+-- even if statisticsPrintAll is called multiple times.
+statisticsPrintAll :: IO ()
+statisticsPrintAll = do
+ hFlush stdout -- lower the chance of overlapping output
+ c_stats_print_all
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (SNat n -> Array n i -> Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr f32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr f64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -> (SNat n -> Array n i -> Array n i -> Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> Array (n + 1) i -> Array n i)
+intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> Int -> i) -- ^ scale op
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -> (SNat n -> Array n i -> i)
+intWidBranchRedFull fsc fred32 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> (Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+ | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+ | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
+ | otherwise = error "Unsupported Int width"
+
+class NumElt a where
+ numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltSub :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltMul :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltNeg :: SNat n -> Array n a -> Array n a
+ numEltAbs :: SNat n -> Array n a -> Array n a
+ numEltSignum :: SNat n -> Array n a -> Array n a
+ numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltSumFull :: SNat n -> Array n a -> a
+ numEltProductFull :: SNat n -> Array n a -> a
+ numEltMinIndex :: SNat n -> Array n a -> [Int]
+ numEltMaxIndex :: SNat n -> Array n a -> [Int]
+ numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
+ numEltSumFull = sumFullVectorInt32
+ numEltProductFull = productFullVectorInt32
+ numEltMinIndex _ = minindexVectorInt32
+ numEltMaxIndex _ = maxindexVectorInt32
+ numEltDotprodInner = dotprodinnerVectorInt32
+
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
+ numEltSumFull = sumFullVectorInt64
+ numEltProductFull = productFullVectorInt64
+ numEltMinIndex _ = minindexVectorInt64
+ numEltMaxIndex _ = maxindexVectorInt64
+ numEltDotprodInner = dotprodinnerVectorInt64
+
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
+ numEltSumFull = sumFullVectorFloat
+ numEltProductFull = productFullVectorFloat
+ numEltMinIndex _ = minindexVectorFloat
+ numEltMaxIndex _ = maxindexVectorFloat
+ numEltDotprodInner = dotprodinnerVectorFloat
+
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
+ numEltSumFull = sumFullVectorDouble
+ numEltProductFull = productFullVectorDouble
+ numEltMinIndex _ = minindexVectorDouble
+ numEltMaxIndex _ = maxindexVectorDouble
+ numEltDotprodInner = dotprodinnerVectorDouble
+
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+class NumElt a => IntElt a where
+ intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
+ intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance IntElt Int32 where
+ intEltQuot = quotVectorInt32
+ intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+ intEltQuot = quotVectorInt64
+ intEltRem = remVectorInt64
+
+instance IntElt Int where
+ intEltQuot = intWidBranch2 @Int quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @Int rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+ intEltQuot = intWidBranch2 @CInt quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @CInt rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+class NumElt a => FloatElt a where
+ floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltRecip :: SNat n -> Array n a -> Array n a
+ floatEltExp :: SNat n -> Array n a -> Array n a
+ floatEltLog :: SNat n -> Array n a -> Array n a
+ floatEltSqrt :: SNat n -> Array n a -> Array n a
+ floatEltSin :: SNat n -> Array n a -> Array n a
+ floatEltCos :: SNat n -> Array n a -> Array n a
+ floatEltTan :: SNat n -> Array n a -> Array n a
+ floatEltAsin :: SNat n -> Array n a -> Array n a
+ floatEltAcos :: SNat n -> Array n a -> Array n a
+ floatEltAtan :: SNat n -> Array n a -> Array n a
+ floatEltSinh :: SNat n -> Array n a -> Array n a
+ floatEltCosh :: SNat n -> Array n a -> Array n a
+ floatEltTanh :: SNat n -> Array n a -> Array n a
+ floatEltAsinh :: SNat n -> Array n a -> Array n a
+ floatEltAcosh :: SNat n -> Array n a -> Array n a
+ floatEltAtanh :: SNat n -> Array n a -> Array n a
+ floatEltLog1p :: SNat n -> Array n a -> Array n a
+ floatEltExpm1 :: SNat n -> Array n a -> Array n a
+ floatEltLog1pexp :: SNat n -> Array n a -> Array n a
+ floatEltLog1mexp :: SNat n -> Array n a -> Array n a
+ floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltPow = powVectorFloat
+ floatEltLogbase = logbaseVectorFloat
+ floatEltRecip = recipVectorFloat
+ floatEltExp = expVectorFloat
+ floatEltLog = logVectorFloat
+ floatEltSqrt = sqrtVectorFloat
+ floatEltSin = sinVectorFloat
+ floatEltCos = cosVectorFloat
+ floatEltTan = tanVectorFloat
+ floatEltAsin = asinVectorFloat
+ floatEltAcos = acosVectorFloat
+ floatEltAtan = atanVectorFloat
+ floatEltSinh = sinhVectorFloat
+ floatEltCosh = coshVectorFloat
+ floatEltTanh = tanhVectorFloat
+ floatEltAsinh = asinhVectorFloat
+ floatEltAcosh = acoshVectorFloat
+ floatEltAtanh = atanhVectorFloat
+ floatEltLog1p = log1pVectorFloat
+ floatEltExpm1 = expm1VectorFloat
+ floatEltLog1pexp = log1pexpVectorFloat
+ floatEltLog1mexp = log1mexpVectorFloat
+ floatEltAtan2 = atan2VectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltPow = powVectorDouble
+ floatEltLogbase = logbaseVectorDouble
+ floatEltRecip = recipVectorDouble
+ floatEltExp = expVectorDouble
+ floatEltLog = logVectorDouble
+ floatEltSqrt = sqrtVectorDouble
+ floatEltSin = sinVectorDouble
+ floatEltCos = cosVectorDouble
+ floatEltTan = tanVectorDouble
+ floatEltAsin = asinVectorDouble
+ floatEltAcos = acosVectorDouble
+ floatEltAtan = atanVectorDouble
+ floatEltSinh = sinhVectorDouble
+ floatEltCosh = coshVectorDouble
+ floatEltTanh = tanhVectorDouble
+ floatEltAsinh = asinhVectorDouble
+ floatEltAcosh = acoshVectorDouble
+ floatEltAtanh = atanhVectorDouble
+ floatEltLog1p = log1pVectorDouble
+ floatEltExpm1 = expm1VectorDouble
+ floatEltLog1pexp = log1pexpVectorDouble
+ floatEltLog1mexp = log1mexpVectorDouble
+ floatEltAtan2 = atan2VectorDouble