Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r--  src/Data/Array/Mixed/Internal/Arith.hs | 928
1 file changed, 11 insertions(+), 917 deletions(-)
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 27ebb64..f7a76bc 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -1,929 +1,23 @@
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TemplateHaskell #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Internal.Arith where
-import Control.Monad (forM, guard)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
-import Data.Bifunctor (second)
-import Data.Bits
-import Data.Int
-import Data.List (sort)
-import Data.Vector.Storable qualified as VS
-import Data.Vector.Storable.Mutable qualified as VSM
-import Foreign.C.Types
-import Foreign.Marshal.Alloc (alloca)
-import Foreign.Ptr
-import Foreign.Storable (Storable(sizeOf), peek, poke)
-import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
-import Language.Haskell.TH
-import System.IO (hFlush, stdout)
-import System.IO.Unsafe
-import Data.Array.Mixed.Internal.Arith.Foreign
-import Data.Array.Mixed.Internal.Arith.Lists
-import Data.Array.Mixed.Types (fromSNat')
+import Data.Array.Strided qualified as AS
--- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+fromO :: RS.Array n a -> AS.Array n a
+fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec
+toO :: AS.Array n a -> RS.Array n a
+toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
--- TODO: test all the cases of this thing with various input strides
-liftVEltwise1 :: (Storable a, Storable b)
- => SNat n
- -> (VS.Vector a -> VS.Vector b)
- -> RS.Array n a -> RS.Array n b
-liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- let vec' = f (VS.slice blockOff blockSz vec)
- in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
- | otherwise = RS.fromVector sh (f (RS.toVector arr))
+liftO1 :: (AS.Array n a -> AS.Array n' b)
+ -> RS.Array n a -> RS.Array n' b
+liftO1 f = toO . f . fromO
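The `+` definitions above are the heart of this patch: an orthotope array (`RS.Array`) and a strided array (`AS.Array`) carry the same (shape, strides, offset, buffer) data, so `fromO` and `toO` merely repack constructors, and `liftO1` transports a function on strided arrays to orthotope arrays. A minimal sketch of the intended round trip (the helper name is hypothetical, not part of the patch):

roundTripAgrees :: RS.Array n Double -> Bool
roundTripAgrees arr = RS.toVector (toO (fromO arr)) == RS.toVector arr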
--- TODO: test all the cases of this thing with various input strides
-{-# NOINLINE liftOpEltwise1 #-}
-liftOpEltwise1 :: (Storable a, Storable b)
- => SNat n
- -> (Ptr a -> Ptr a')
- -> (Ptr b -> Ptr b')
- -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ())
- -> RS.Array n a -> RS.Array n b
-liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
- -- TODO: less code duplication between these two branches
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- if blockSz == 0
- then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty))
- else unsafePerformIO $ do
- outv <- VSM.unsafeNew blockSz
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
- VS.unsafeWith (VS.singleton 1) $ \pstrides ->
- VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
- cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
- RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
- | otherwise = unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
- cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
--- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (a -> a -> a)
- -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
-liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
- arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
- arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
- | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
- (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
- let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
- in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
-
- (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
- let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
- in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec))
-
- (Just (_, 1), Nothing) -> -- scalar * array
- wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
-
- (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
- let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
- in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec))
-
- (Nothing, Just (_, 1)) -> -- array * scalar
- wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
-
- (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
- | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the strides check
- , strides1 == strides2
- -> -- dense * dense but the strides match
- let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
- arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
- in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec))
-
- (_, _) -> -- fallback case
- wrapBinaryVV sn ptrconv f_vv arr1 arr2
-
--- | Given a shape vector, offset and stride vector, check whether this virtual
--- vector uses a dense subarray of its backing array. If so, the first index
--- and the number of elements in this subarray are returned.
--- (The offset itself is not part of the check: the dense block may start
--- anywhere in the backing array.)
-stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
-stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
-stridesDense sh offsetNeg stridesNeg =
- -- First reverse all dimensions with negative stride, so that the first used
- -- value is at 'offset' and the rest is >= offset.
- let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
- in -- sort dimensions on their stride, ascending, dropping any zero strides
- case filter ((/= 0) . fst) (sort (zip strides sh)) of
- [] -> Just (offset, 1)
- (1, n) : pairs -> (offset,) <$> checkCover n pairs
- _ -> Nothing -- if the smallest stride is not 1, it will never be dense
- where
- -- Given the size of the densely covered region at the beginning of the
- -- array, and the remaining (stride, size) pairs with all strides >= 1,
- -- check whether these together cover a dense prefix of the array. If
- -- they do, return the number of elements in this prefix.
- checkCover :: Int -> [(Int, Int)] -> Maybe Int
- checkCover block [] = Just block
- checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs
-
- -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
- flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
- flipReverseds [] off [] = (off, [])
- flipReverseds (n : sh') off (s : str')
- | s >= 0 = second (s :) (flipReverseds sh' off str')
- | otherwise =
- let off' = off + (n - 1) * s
- in second ((-s) :) (flipReverseds sh' off' str')
- flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
-
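For intuition about the removed `stridesDense` check, a few worked calls (values chosen purely for illustration):

-- stridesDense [2,3] 4 [3,1] == Just (4, 6)  -- rows adjacent: dense block of 6
-- stridesDense [2,3] 4 [4,1] == Nothing      -- one-element gap between rows
-- stridesDense [2,3] 0 [0,1] == Just (0, 3)  -- replicated outer dim: 3 dense elements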
-{-# NOINLINE wrapBinarySV #-}
-wrapBinarySV :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
- -> a -> RS.Array n a
- -> RS.Array n a
-wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) =
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
- cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
-wrapBinaryVS :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
- -> RS.Array n a -> a
- -> RS.Array n a
-wrapBinaryVS sn valconv ptrconv cf_strided arr y =
- wrapBinarySV sn valconv ptrconv
- (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
-
--- | This function assumes that the two shapes are equal.
-{-# NOINLINE wrapBinaryVV #-}
-wrapBinaryVV :: Storable a
- => SNat n
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
- -> RS.Array n a -> RS.Array n a
- -> RS.Array n a
-wrapBinaryVV sn@SNat ptrconv cf_strided
- (RS.A (RG.A sh (OI.T strides1 offset1 vec1)))
- (RS.A (RG.A _ (OI.T strides2 offset2 vec2))) =
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
- VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 ->
- VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 ->
- cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
-{-# NOINLINE vectorOp1 #-}
-vectorOp1 :: forall a b. Storable a
- => (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr b -> IO ())
- -> VS.Vector a -> VS.Vector a
-vectorOp1 ptrconv f v = unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length v)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith v $ \pv ->
- f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
- VS.unsafeFreeze outv
-
--- | If two vectors are given, assumes that they have the same length.
-{-# NOINLINE vectorOp2 #-}
-vectorOp2 :: forall a b. Storable a
- => (a -> b)
- -> (Ptr a -> Ptr b)
- -> (a -> a -> a)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
- -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
- -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
- -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
-vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
- (Left x) (Left y) -> VS.singleton (fss x y)
-
- (Left x) (Right vy) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vy)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vy $ \pvy ->
- fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
- VS.unsafeFreeze outv
-
- (Right vx) (Left y) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
- VS.unsafeFreeze outv
-
- (Right vx) (Right vy)
- | VS.length vx == VS.length vy ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- VS.unsafeWith vy $ \pvy ->
- fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
- VS.unsafeFreeze outv
- | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
-
--- TODO: test handling of negative strides
--- | Reduce along the inner dimension
-{-# NOINLINE vectorRedInnerOp #-}
-vectorRedInnerOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> RS.Array (n + 1) a -> RS.Array n a
-vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = error "unreachable"
- | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0])
- | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty))
- -- now the input array is nonempty
- | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
- | last strides == 0 =
- liftVEltwise1 sn
- (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
- (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
- -- now there is useful work along the inner dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The reduction
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- -- replace replicated dimensions with ones
- shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
- ndimsF = length shF -- > 0, otherwise `last strides == 0`
-
- -- reversed dimensions: dimensions with negative stride. Reversal is
- -- irrelevant for a reduction, and indeed the kernel has a
- -- precondition that there are no such dimensions.
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
- -- The *R values give an array with strides all > 0, hence the
- -- left-most element is at offsetR.
- in unsafePerformIO $ do
- outvR <- VSM.unsafeNew (product (init shF))
- VSM.unsafeWith outvR $ \poutvR ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
- TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
- RS.stretch (init sh) -- replicate to original shape
- . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
- . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions
- . RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array
- <$> VS.unsafeFreeze outvR
-
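The `last strides == 0` branch of `vectorRedInnerOp` above exploits a simple identity: for a sum, reducing an inner dimension that is replicated k times equals scaling each remaining element by k, which is why a scale-by-constant kernel suffices there. A hedged plain-Haskell sketch of that identity (not part of the module):

sumOfReplicated :: Num a => Int -> a -> a
sumOfReplicated k x = sum (replicate k x)  -- equals fromIntegral k * x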
--- TODO: test handling of negative strides
--- | Reduce full array
-{-# NOINLINE vectorRedFullOp #-}
-vectorRedFullOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> Int -> a)
- -> (b -> a)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
- -> RS.Array n a -> a
-vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = vec VS.! offset -- 0D array has one element
- | any (<= 0) sh = 0
- -- now the input array is nonempty
- | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset
- -- now there is at least one non-replicated dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The reduction
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- ndimsF = length shF -- > 0, otherwise `all (== 0) strides`
- -- we should scale up the output this many times to account for the replicated dimensions
- multiplier = product [n | (n, True) <- zip sh replDims]
-
- -- reversed dimensions: dimensions with negative stride. Reversal is
- -- irrelevant for a reduction, and indeed the kernel has a
- -- precondition that there are no such dimensions.
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
- -- The *R values give an array with strides all > 0, hence the
- -- left-most element is at offsetR.
- in unsafePerformIO $ do
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- (`scaleval` multiplier) . valbackconv
- <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
-
--- TODO: test this function
--- | Find extremum (minindex ("argmin") or maxindex) in full array
-{-# NOINLINE vectorExtremumOp #-}
-vectorExtremumOp :: forall a b n. Storable a
- => (Ptr a -> Ptr b)
- -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
- -> RS.Array n a -> [Int] -- result length: n
-vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = []
- | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
- -- now the input array is nonempty
- | all (== 0) strides = 0 <$ sh
- -- now there is at least one non-replicated dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The extremum
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- ndimsF = length shF -- > 0, because not all strides were 0
-
- -- un-reverse reversed dimensions
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
-
- -- function to insert zeros in replicated-out dimensions
- insertZeros :: [Bool] -> [Int] -> [Int]
- insertZeros [] idx = idx
- insertZeros (True : repls) idx = 0 : insertZeros repls idx
- insertZeros (False : repls) (i : idx) = i : insertZeros repls idx
- insertZeros (_:_) [] = error "unreachable"
- in unsafePerformIO $ do
- outvR <- VSM.unsafeNew (length shF)
- VSM.unsafeWith outvR $ \poutvR ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
- insertZeros replDims
- . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF -- re-reverse the reversed dimensions
- . map (fromIntegral @Int64 @Int)
- . VS.toList
- <$> VS.unsafeFreeze outvR
-
-vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
- -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
-vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
- arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
- arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | null sh1 || null sh2 = error "unreachable"
- | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0])
- | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty))
- -- now the input arrays are nonempty
- | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2)
- | last strides1 == 0 =
- fmul sn
- (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1)))
- (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
- | last strides2 == 0 =
- fmul sn
- (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
- (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2)))
- -- now there is useful dotprod work along the inner dimension
- | otherwise = unsafePerformIO $ do
- let inrank = fromSNat' sn + 1
- outv <- VSM.unsafeNew (product (init sh1))
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
- pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
- pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
- RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
-
-{-# NOINLINE dotScalarVector #-}
-dotScalarVector :: forall a b. (Num a, Storable a)
- => Int -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> a -> VS.Vector a -> a
-dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
- alloca @a $ \pout -> do
- alloca @Int64 $ \pshape -> do
- poke pshape (fromIntegral @Int @Int64 len)
- alloca @Int64 $ \pstride -> do
- poke pstride 1
- VS.unsafeWith vec $ \pvec ->
- fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)
- res <- peek pout
- return (scalar * res)
-
-{-# NOINLINE dotVectorVector #-}
-dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
- -> VS.Vector a -> VS.Vector a -> a
-dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2)
-
-{-# NOINLINE dotVectorVectorStrided #-}
-dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
- -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel
- -> Int -> Int -> VS.Vector a
- -> Int -> Int -> VS.Vector a
- -> a
-dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- valbackconv <$> fdot (fromIntegral @Int @Int64 len)
- (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1)
- (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2)
-
-flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
- -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
-flipOp f n out v s = f n out s v
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_binary_" ++ atCName arithtype
- c_ss_str = varE (aboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM intTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_ibinary_" ++ atCName arithtype
- c_ss_str = varE (aiboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_fbinary_" ++ atCName arithtype
- c_ss_str = varE (afboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-mulWithInt :: Num a => a -> Int -> a
-mulWithInt a i = a * fromIntegral i
-
-scaleFromSVStrided :: (Int64 -> Ptr Int64 -> Ptr a -> a -> Ptr Int64 -> Ptr a -> IO ())
- -> Int64 -> Ptr a -> a -> Ptr a -> IO ()
-scaleFromSVStrided fsv n out x ys =
- VS.unsafeWith (VS.singleton n) $ \psh ->
- VS.unsafeWith (VS.singleton 1) $ \pstrides ->
- fsv 1 psh out x pstrides ys
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let scaleVar = case arithop of
- RO_SUM -> varE 'mulWithInt
- RO_PRODUCT -> varE '(^)
- let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
- namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
- c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
- c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
- c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- sequence [SigD name1 <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorRedInnerOp sn id id (scaleFromSVStrided $c_scale_op) $c_op1 |]
- return $ FunD name1 [Clause [] (NormalB body) []]
- ,SigD namefull <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |]
- ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
- return $ FunD namefull [Clause [] (NormalB body) []]
- ])
-
-$(fmap concat . forM typesList $ \arithtype ->
- fmap concat . forM ["min", "max"] $ \fname -> do
- let ttyp = conT (atType arithtype)
- name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
- sequence [SigD name <$>
- [t| forall n. RS.Array n $ttyp -> [Int] |]
- ,do body <- [| vectorExtremumOp id $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
- mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
- c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
-foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
-
-statisticsEnable :: Bool -> IO ()
-statisticsEnable b = c_stats_enable (if b then 1 else 0)
-
--- | Consumes the log: one particular event will only ever be printed once,
--- even if statisticsPrintAll is called multiple times.
-statisticsPrintAll :: IO ()
-statisticsPrintAll = do
- hFlush stdout -- lower the chance of overlapping output
- c_stats_print_all
-
--- This branch is ostensibly a runtime branch, but will (hopefully) be
--- constant-folded away by GHC.
-intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
- => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
- -> (SNat n -> RS.Array n i -> RS.Array n i)
-intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32
- | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64
- | otherwise = error "Unsupported Int width"
-
-intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
- => (i -> i -> i) -- ss
- -- int32
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- vv
- -- int64
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
- -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
-intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
- | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
- => (i -> Int -> i) -- ^ scale op
- -- int32
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel
- -> (SNat n -> RS.Array n i -> i)
-intWidBranchRedFull fsc fred32 fred64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
- | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ extremum kernel
- -- int64
- -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ extremum kernel
- -> (RS.Array n i -> [Int])
-intWidBranchExtr fextr32 fextr64
- | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
- | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ dotprod kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
- | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
- | otherwise = error "Unsupported Int width"
-
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltSumFull :: SNat n -> RS.Array n a -> a
- numEltProductFull :: SNat n -> RS.Array n a -> a
- numEltMinIndex :: SNat n -> RS.Array n a -> [Int]
- numEltMaxIndex :: SNat n -> RS.Array n a -> [Int]
- numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
-
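A hedged usage sketch for the `NumElt` class above (hypothetical call, reducing the inner dimension of a 2x3 array):

-- numEltSum1Inner (SNat @1) (RS.fromList [2,3] [1,2,3,4,5,6])
--   == RS.fromList [2] [6, 15]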
-instance NumElt Int32 where
- numEltAdd = addVectorInt32
- numEltSub = subVectorInt32
- numEltMul = mulVectorInt32
- numEltNeg = negVectorInt32
- numEltAbs = absVectorInt32
- numEltSignum = signumVectorInt32
- numEltSum1Inner = sum1VectorInt32
- numEltProduct1Inner = product1VectorInt32
- numEltSumFull = sumFullVectorInt32
- numEltProductFull = productFullVectorInt32
- numEltMinIndex _ = minindexVectorInt32
- numEltMaxIndex _ = maxindexVectorInt32
- numEltDotprodInner = dotprodinnerVectorInt32
-
-instance NumElt Int64 where
- numEltAdd = addVectorInt64
- numEltSub = subVectorInt64
- numEltMul = mulVectorInt64
- numEltNeg = negVectorInt64
- numEltAbs = absVectorInt64
- numEltSignum = signumVectorInt64
- numEltSum1Inner = sum1VectorInt64
- numEltProduct1Inner = product1VectorInt64
- numEltSumFull = sumFullVectorInt64
- numEltProductFull = productFullVectorInt64
- numEltMinIndex _ = minindexVectorInt64
- numEltMaxIndex _ = maxindexVectorInt64
- numEltDotprodInner = dotprodinnerVectorInt64
-
-instance NumElt Float where
- numEltAdd = addVectorFloat
- numEltSub = subVectorFloat
- numEltMul = mulVectorFloat
- numEltNeg = negVectorFloat
- numEltAbs = absVectorFloat
- numEltSignum = signumVectorFloat
- numEltSum1Inner = sum1VectorFloat
- numEltProduct1Inner = product1VectorFloat
- numEltSumFull = sumFullVectorFloat
- numEltProductFull = productFullVectorFloat
- numEltMinIndex _ = minindexVectorFloat
- numEltMaxIndex _ = maxindexVectorFloat
- numEltDotprodInner = dotprodinnerVectorFloat
-
-instance NumElt Double where
- numEltAdd = addVectorDouble
- numEltSub = subVectorDouble
- numEltMul = mulVectorDouble
- numEltNeg = negVectorDouble
- numEltAbs = absVectorDouble
- numEltSignum = signumVectorDouble
- numEltSum1Inner = sum1VectorDouble
- numEltProduct1Inner = product1VectorDouble
- numEltSumFull = sumFullVectorDouble
- numEltProductFull = productFullVectorDouble
- numEltMinIndex _ = minindexVectorDouble
- numEltMaxIndex _ = maxindexVectorDouble
- numEltDotprodInner = dotprodinnerVectorDouble
-
-instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+)
- (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
- (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @Int (-)
- (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
- (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @Int (*)
- (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
- (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed1 @Int
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))
- numEltProduct1Inner = intWidBranchRed1 @Int
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))
- numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
- numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
- numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
- numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
- numEltDotprodInner = intWidBranchDotprod @Int (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
-
-instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+)
- (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
- (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @CInt (-)
- (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
- (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @CInt (*)
- (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
- (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed1 @CInt
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))
- numEltProduct1Inner = intWidBranchRed1 @CInt
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))
- numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
- numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
- numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
- numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
- numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
-
-class NumElt a => IntElt a where
- intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
-
-instance IntElt Int32 where
- intEltQuot = quotVectorInt32
- intEltRem = remVectorInt32
-
-instance IntElt Int64 where
- intEltQuot = quotVectorInt64
- intEltRem = remVectorInt64
-
-instance IntElt Int where
- intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
- intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
-
-instance IntElt CInt where
- intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
- intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
-
-class NumElt a => FloatElt a where
- floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
-
-instance FloatElt Float where
- floatEltDiv = divVectorFloat
- floatEltPow = powVectorFloat
- floatEltLogbase = logbaseVectorFloat
- floatEltRecip = recipVectorFloat
- floatEltExp = expVectorFloat
- floatEltLog = logVectorFloat
- floatEltSqrt = sqrtVectorFloat
- floatEltSin = sinVectorFloat
- floatEltCos = cosVectorFloat
- floatEltTan = tanVectorFloat
- floatEltAsin = asinVectorFloat
- floatEltAcos = acosVectorFloat
- floatEltAtan = atanVectorFloat
- floatEltSinh = sinhVectorFloat
- floatEltCosh = coshVectorFloat
- floatEltTanh = tanhVectorFloat
- floatEltAsinh = asinhVectorFloat
- floatEltAcosh = acoshVectorFloat
- floatEltAtanh = atanhVectorFloat
- floatEltLog1p = log1pVectorFloat
- floatEltExpm1 = expm1VectorFloat
- floatEltLog1pexp = log1pexpVectorFloat
- floatEltLog1mexp = log1mexpVectorFloat
- floatEltAtan2 = atan2VectorFloat
-
-instance FloatElt Double where
- floatEltDiv = divVectorDouble
- floatEltPow = powVectorDouble
- floatEltLogbase = logbaseVectorDouble
- floatEltRecip = recipVectorDouble
- floatEltExp = expVectorDouble
- floatEltLog = logVectorDouble
- floatEltSqrt = sqrtVectorDouble
- floatEltSin = sinVectorDouble
- floatEltCos = cosVectorDouble
- floatEltTan = tanVectorDouble
- floatEltAsin = asinVectorDouble
- floatEltAcos = acosVectorDouble
- floatEltAtan = atanVectorDouble
- floatEltSinh = sinhVectorDouble
- floatEltCosh = coshVectorDouble
- floatEltTanh = tanhVectorDouble
- floatEltAsinh = asinhVectorDouble
- floatEltAcosh = acoshVectorDouble
- floatEltAtanh = atanhVectorDouble
- floatEltLog1p = log1pVectorDouble
- floatEltExpm1 = expm1VectorDouble
- floatEltLog1pexp = log1pexpVectorDouble
- floatEltLog1mexp = log1mexpVectorDouble
- floatEltAtan2 = atan2VectorDouble
+liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
+ -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
+liftO2 f x y = toO (f (fromO x) (fromO y))
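Downstream code wraps concrete strided-array operations with these lifters. A hedged sketch of the calling pattern (the wrapped function is a deliberate no-op that just repacks the representation, for illustration only):

noopViaStrided :: RS.Array n Double -> RS.Array n Double
noopViaStrided =
  liftO1 (\(AS.Array sh strides offset vec) -> AS.Array sh strides offset vec)

Binary operations follow the same pattern through `liftO2`.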