author     Tom Smeding <tom@tomsmeding.com>   2025-03-20 13:01:24 +0100
committer  Tom Smeding <tom@tomsmeding.com>   2025-03-20 13:01:24 +0100
commit     55036a5ea4a6e590d0404638b2823c6a4aec3fba
tree       484bc377229d3edff36bd9a2a80f999bbcd2e889
parent     5414434df62b2b196354b9748b265093c168601b
Separate arith routines into a library
The point is that this separate library does not depend on orthotope.
-rw-r--r--   bench/Main.hs                                        |    2
-rw-r--r--   cbits/arith.c                                        |   40
-rw-r--r--   ops/Data/Array/Strided.hs                            |    7
-rw-r--r--   ops/Data/Array/Strided/Arith.hs                      |    7
-rw-r--r--   ops/Data/Array/Strided/Arith/Internal.hs             |  866
-rw-r--r--   ops/Data/Array/Strided/Arith/Internal/Foreign.hs (renamed from src/Data/Array/Mixed/Internal/Arith/Foreign.hs)   |    4
-rw-r--r--   ops/Data/Array/Strided/Arith/Internal/Lists.hs (renamed from src/Data/Array/Mixed/Internal/Arith/Lists.hs)       |    4
-rw-r--r--   ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs (renamed from src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs) |    2
-rw-r--r--   ops/Data/Array/Strided/Array.hs                      |   42
-rw-r--r--   ox-arrays.cabal                                      |   27
-rw-r--r--   src/Data/Array/Mixed/Internal/Arith.hs               |  928
-rw-r--r--   src/Data/Array/Mixed/XArray.hs                       |    5
-rw-r--r--   src/Data/Array/Nested.hs                             |    2
-rw-r--r--   src/Data/Array/Nested/Internal/Mixed.hs              |   75
-rw-r--r--   src/Data/Array/Nested/Internal/Ranked.hs             |    2
-rw-r--r--   src/Data/Array/Nested/Internal/Shaped.hs             |    2
16 files changed, 1028 insertions, 987 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 6e83270..3ab81a8 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -12,7 +12,7 @@ import Test.Tasty.Bench
import Data.Array.Nested
import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2)
import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
-import qualified Data.Array.Mixed.Internal.Arith as Arith
+import qualified Data.Array.Strided.Arith.Internal as Arith
enableMisc :: Bool
diff --git a/cbits/arith.c b/cbits/arith.c
index ca0af51..b574d54 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -24,6 +24,17 @@ typedef int32_t i32;
typedef int64_t i64;
+// PRECONDITIONS
+//
+// All strided array operations in this file assume that none of the shape
+// components are zero -- that is, the input arrays are non-empty. This must
+// be arranged on the Haskell side.
+//
+// Furthermore, note that while the Haskell side has an offset into the backing
+// vector, the C side assumes that the offset is zero. Shift the pointer if
+// necessary.
+
+
/*****************************************************************************
* Performance statistics *
*****************************************************************************/
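
The PRECONDITIONS comment above puts the burden on the Haskell callers: they must never pass a zero-sized shape to these kernels, and they must fold the array offset into the data pointer themselves. A minimal sketch of that calling convention, assuming the unary kernel signature used in this commit (rank, out, shape, strides, input); the names are illustrative and not code from the commit, and the real wrapper is wrapUnary in ops/Data/Array/Strided/Arith/Internal.hs below:

{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int64)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (Storable, sizeOf)

-- Call a strided C kernel on a non-empty array, shifting the data pointer
-- by the Haskell-side offset so the C side can assume offset 0.
callStrided :: forall a. Storable a
            => (Int64 -> Ptr a -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ())
            -> [Int]         -- shape; caller guarantees all components > 0
            -> [Int]         -- strides
            -> Int           -- offset into the backing vector
            -> VS.Vector a   -- backing vector
            -> IO (VS.Vector a)
callStrided kernel sh strides offset vec = do
  outv <- VSM.unsafeNew (product sh)
  VSM.unsafeWith outv $ \pout ->
    VS.unsafeWith (VS.fromListN (length sh) (map fromIntegral sh)) $ \psh ->
    VS.unsafeWith (VS.fromListN (length sh) (map fromIntegral strides)) $ \pstr ->
    VS.unsafeWith vec $ \pv ->
      -- shift the pointer so that the C kernel sees offset 0
      kernel (fromIntegral (length sh)) pout psh pstr
             (pv `plusPtr` (offset * sizeOf (undefined :: a)))
  VS.unsafeFreeze outv
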
@@ -370,6 +381,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define COMM_OP_STRIDED(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ if (rank == 0) { out[0] = x op y[0]; return; } \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
@@ -377,6 +389,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
} \
static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ if (rank == 0) { out[0] = x[0] op y[0]; return; } \
TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
@@ -387,6 +400,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define NONCOMM_OP_STRIDED(name, op, typ) \
COMM_OP_STRIDED(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ if (rank == 0) { out[0] = x[0] op y; return; } \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
@@ -396,6 +410,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define PREFIX_BINOP_STRIDED(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ if (rank == 0) { out[0] = op(x, y[0]); return; } \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
@@ -403,6 +418,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
} \
static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ if (rank == 0) { out[0] = op(x[0], y[0]); return; } \
TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
@@ -410,6 +426,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
} \
static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ if (rank == 0) { out[0] = op(x[0], y); return; } \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
@@ -424,6 +441,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
fprintf(stderr, " strides="); \
print_shape(stderr, rank, strides); \
fprintf(stderr, "\n"); */ \
+ if (rank == 0) { out[0] = op(arr[0]); return; } \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \
@@ -434,7 +452,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// Used for reduction and dot product kernels below
#define MANUAL_VECT_WID 8
-// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions
+// Used in REDUCE1_OP and REDUCEFULL_OP below
#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \
do { \
const i64 n = innerLen; const i64 s = innerStride; \
@@ -458,11 +476,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
} \
} while (0)
-// preconditions:
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
-// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
@@ -472,12 +485,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
}
-// preconditions
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
#define REDUCEFULL_OP(name, op, typ) \
typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ if (rank == 0) return arr[0]; \
typ result = 0; \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \
@@ -485,13 +495,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
return result; \
}
-// preconditions
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
#define EXTREMUM_OP(name, cmp, typ) \
void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ if (rank == 0) return; /* output index vector has length 0 anyways */ \
typ best = arr[0]; \
memset(outidx, 0, rank * sizeof(i64)); \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
@@ -527,11 +534,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
} \
}
-// preconditions:
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
-// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define DOTPROD_INNER_OP(typ) \
diff --git a/ops/Data/Array/Strided.hs b/ops/Data/Array/Strided.hs
new file mode 100644
index 0000000..a0506a9
--- /dev/null
+++ b/ops/Data/Array/Strided.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided (
+ module Data.Array.Strided.Array,
+ module Data.Array.Strided.Arith,
+) where
+
+import Data.Array.Strided.Array
+import Data.Array.Strided.Arith
diff --git a/ops/Data/Array/Strided/Arith.hs b/ops/Data/Array/Strided/Arith.hs
new file mode 100644
index 0000000..7be6390
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided.Arith (
+ NumElt(..),
+ IntElt(..),
+ FloatElt(..),
+) where
+
+import Data.Array.Strided.Arith.Internal
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
new file mode 100644
index 0000000..fe0fc4b
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -0,0 +1,866 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Strided.Arith.Internal where
+
+import Control.Monad
+import Data.Bifunctor (second)
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable
+import qualified GHC.TypeNats as TypeNats
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO (hFlush, stdout)
+import System.IO.Unsafe
+
+import Data.Array.Strided.Array
+import Data.Array.Strided.Arith.Internal.Lists
+import Data.Array.Strided.Arith.Internal.Foreign
+
+
+-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+
+
+-- TODO: move this to a utilities module
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
+data Dict c where
+ Dict :: c => Dict c
+
+debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
+debugShow (Array sh strides offset vec) =
+ "Array @" ++ (show (natVal (Proxy @n))) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">"
+
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise1 :: (Storable a, Storable b)
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a -> Array n a
+liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ if blockSz == 0
+ then Array sh (map (const 0) strides) 0 VS.empty
+ else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec)
+ in Array sh strides (offset - blockOff) resvec
+ | otherwise = wrapUnary sn ptrconv cf_strided arr
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise2 :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv
+ -> Array n a -> Array n a -> Array n a
+liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
+ arr1@(Array sh1 strides1 offset1 vec1)
+ arr2@(Array sh2 strides2 offset2 vec2)
+ | sh1 /= sh2 = error $ "liftOpEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty
+ | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
+ (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
+ in Array sh1 strides1 0 vec'
+
+ (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
+ let arr2' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec2)
+ resvec = arrValues $ wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+ in Array sh1 strides2 (offset2 - blockOff) resvec
+
+ (Just (_, 1), Nothing) -> -- scalar * array
+ wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
+ (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
+ let arr1' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec1)
+ resvec = arrValues $ wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+ in Array sh1 strides1 (offset1 - blockOff) resvec
+
+ (Nothing, Just (_, 1)) -> -- array * scalar
+ wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
+ (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
+ | strides1 == strides2
+ -> -- dense * dense but the strides match
+ if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2
+ then error $ "Data.Array.Strided.Arith.Internal (liftOpEltwise2): Internal error: cannot happen " ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2))
+ else
+ let arr1' = arrayFromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+ arr2' = arrayFromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+ resvec = arrValues $ wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
+ in Array sh1 strides1 (offset1 - blockOff1) resvec
+
+ (_, _) -> -- fallback case
+ wrapBinaryVV sn ptrconv f_vv arr1 arr2
+
+-- | Given shape vector, offset and stride vector, check whether this virtual
+-- vector uses a dense subarray of its backing array. If so, the first index
+-- and the number of elements in this subarray are returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
+stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
+stridesDense sh offsetNeg stridesNeg =
+ -- First reverse all dimensions with negative stride, so that the first used
+ -- value is at 'offset' and the rest is >= offset.
+ let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
+ in -- sort dimensions on their stride, ascending, dropping any zero strides
+ case filter ((/= 0) . fst) (sort (zip strides sh)) of
+ [] -> Just (offset, 1)
+ (1, n) : pairs -> (offset,) <$> checkCover n pairs
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
+ where
+ -- Given size of currently densely covered region at beginning of the
+ -- array and the remaining (stride, size) pairs with all strides >=1,
+ -- return whether this all together covers a dense prefix of the array. If
+ -- it does, return the number of elements in this prefix.
+ checkCover :: Int -> [(Int, Int)] -> Maybe Int
+ checkCover block [] = Just block
+ checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs
+
+ -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
+ flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
+ flipReverseds [] off [] = (off, [])
+ flipReverseds (n : sh') off (s : str')
+ | s >= 0 = second (s :) (flipReverseds sh' off str')
+ | otherwise =
+ let off' = off + (n - 1) * s
+ in second ((-s) :) (flipReverseds sh' off' str')
+ flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
+
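To make the contract of stridesDense concrete, here are a few worked calls (a sketch, not part of the commit; the results follow from the definition above):

-- shape [2,3], strides [3,1], offset 5: elements 5..10 are used
-- contiguously, so the dense block starts at index 5 and spans 6 elements.
--   stridesDense [2,3] 5 [3,1]  ==  Just (5, 6)
-- shape [2,3], strides [4,1]: there is a one-element gap between the two
-- rows, so the used region is not dense.
--   stridesDense [2,3] 0 [4,1]  ==  Nothing
-- fully replicated (all strides 0): only a single element is read.
--   stridesDense [2,3] 7 [0,0]  ==  Just (7, 1)
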
+data Unreplicated a =
+ forall n'. KnownNat n' =>
+ -- | Let the original array, with replicated dimensions, be called A.
+ Unreplicated -- | An array with all strides /= 0. Call this array U. It has
+ -- the same shape as A, except with all the replicated (stride
+ -- == 0) dimensions removed. The shape of U is the
+ -- "unreplicated shape".
+ (Array n' a)
+ -- | Product of sizes of the unreplicated dimensions
+ Int
+ -- | Given the stride vector of an array with the unreplicated
+ -- shape, this function reinserts zeros so that it may be
+ -- combined with the original shape of A.
+ ([Int] -> [Int])
+
+-- | Removes all replicated dimensions (i.e. those with stride == 0) from the array.
+unreplicateStrides :: Array n a -> Unreplicated a
+unreplicateStrides (Array sh strides offset vec) =
+ let replDims = map (== 0) strides
+ (shF, stridesF) = unzip [(n, s) | (n, s) <- zip sh strides, s /= 0]
+
+ reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+ reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+ reinsertZeros [] [] = []
+ reinsertZeros (False : _) [] = error $ "unreplicateStrides: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error $ "unreplicateStrides: Internal error: reply strides too long"
+
+ unrepSize = product [n | (n, True) <- zip sh replDims]
+
+ in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
+
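An illustration of the unreplication bookkeeping (a sketch, not part of the commit): take an array A of shape [2,3,4] whose middle dimension is replicated.

-- A: shape [2,3,4], strides [4,0,1]   (middle dimension has stride 0)
-- unreplicateStrides A yields:
--   U:         shape [2,4], strides [4,1]   (replicated dimension dropped)
--   unrepSize: 3                            (product of the dropped sizes)
--   reinsert:  \[a,b] -> [a,0,b]            (zeros restored at dropped spots)
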
+simplifyArray :: Array n a
+ -> (forall n'. KnownNat n'
+ => Array n' a -- U
+ -- Product of sizes of the unreplicated dimensions
+ -> Int
+ -- Convert index in U back to index into original
+ -- array. Replicated dimensions get 0.
+ -> ([Int] -> [Int])
+ -- Given a new array of the same shape as U, convert
+ -- it back to the original shape and iteration order.
+ -> (Array n' a -> Array n a)
+ -- Do the same except without the INNER dimension.
+ -- This throws an error if the inner dimension had
+ -- stride 0.
+ -> (Array (n' - 1) a -> Array (n - 1) a)
+ -> r)
+ -> r
+simplifyArray array k
+ | let revDims = map (<0) (arrStrides array)
+ , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
+ = k array'
+ unrepSize
+ (\idx -> rereplicate (zipWith3 (\b n i -> if b then n - 1 - i else i)
+ revDims (arrShape array') idx))
+ (\(Array sh' strides' offset' vec') ->
+ if sh' == arrShape array'
+ then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
+ else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
+ (\(Array sh' strides' offset' vec') ->
+ if | sh' /= init (arrShape array') ->
+ error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
+ | last (arrStrides array) == 0 ->
+ error $ "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
+ | otherwise ->
+ arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
+
+{-# NOINLINE wrapUnary #-}
+wrapUnary :: forall a b n. Storable a
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a
+ -> Array n a
+wrapUnary _ ptrconv cf_strided array =
+ simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+ let ndims' = length sh
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+ VS.unsafeWith vec $ \pv ->
+ let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral ndims') (ptrconv poutv) psh pstrides pv'
+ restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: forall a b n. Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+ -> a -> Array n a
+ -> Array n a
+wrapBinarySV SNat valconv ptrconv cf_strided x array =
+ simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+ let ndims' = length sh
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+ VS.unsafeWith vec $ \pv ->
+ let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral ndims') psh (ptrconv poutv) (valconv x) pstrides pv'
+ restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+wrapBinaryVS :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+ -> Array n a -> a
+ -> Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+ wrapBinarySV sn valconv ptrconv
+ (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
+
+-- | The two shapes must be equal and non-empty. This is checked.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: forall a b n. Storable a
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+ -> Array n a -> Array n a
+ -> Array n a
+-- TODO: do unreversing and unreplication on the input arrays (but
+-- simultaneously: can only unreplicate if _both_ are replicated on that
+-- dimension)
+wrapBinaryVV sn@SNat ptrconv cf_strided
+ (Array sh strides1 offset1 vec1)
+ (Array sh2 strides2 offset2 vec2)
+ | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
+ | any (<= 0) sh = error $ "wrapBinaryVV: empty shape: " ++ show sh
+ | otherwise = unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith vec1 $ \pv1 ->
+ VS.unsafeWith vec2 $ \pv2 ->
+ let pv1' = pv1 `plusPtr` (offset1 * sizeOf (undefined :: a))
+ pv2' = pv2 `plusPtr` (offset2 * sizeOf (undefined :: a))
+ in cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 pv1' pstrides2 pv2'
+ arrayFromVector sh <$> VS.unsafeFreeze outv
+
+-- TODO: test handling of negative strides
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> Array (n + 1) a -> Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec)
+ | null sh = error "unreachable"
+ | last sh <= 0 = arrayFromConstant (init sh) 0
+ | any (<= 0) (init sh) = Array (init sh) (0 <$ init strides) 0 VS.empty
+ -- now the input array is nonempty
+ | last sh == 1 = Array (init sh) (init strides) offset vec
+ | last strides == 0 =
+ wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a (last sh))
+ (Array (init sh) (init strides) offset vec)
+ -- now there is useful work along the inner dimension
+ -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point.
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore -> unsafePerformIO $ do
+ let ndims' = length sh'
+ outv <- VSM.unsafeNew (product (init sh'))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
+ TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+ LTI -> pure Dict
+ EQI -> pure Dict
+ _ -> error "impossible" -- because `last strides /= 0`
+ case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+ Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv
+ Nothing -> error "impossible"
+
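For intuition on the `last strides == 0` case above (a sketch, not part of the commit): if the inner dimension is replicated, a sum over it is just a multiplication by its length, which is exactly what the fscale kernel is used for there.

-- shape [2,3], strides [1,0], offset 0, values [10,20]
-- (each row repeats one value three times)
-- summing over the inner dimension:
--   [10+10+10, 20+20+20]  ==  [3*10, 3*20]
-- so the reduction becomes wrapBinarySV with the scalar 3 applied to the
-- outer array of shape [2] with strides [1].
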
+-- TODO: test handling of negative strides
+-- | Reduce full array
+{-# NOINLINE vectorRedFullOp #-}
+vectorRedFullOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> Int -> a)
+ -> (b -> a)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -> Array n a -> a
+vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec)
+ | null sh = vec VS.! offset -- 0D array has one element
+ | any (<= 0) sh = 0
+ -- now the input array is nonempty
+ | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset
+ -- now there is at least one non-replicated dimension
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do
+ let ndims' = length sh'
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in (`scaleval` unrepSize) . valbackconv
+ <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv')
+
+-- TODO: test this function
+-- | Find extremum (minindex ("argmin") or maxindex) in full array
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> Array n a -> [Int] -- result length: n
+vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
+ | null sh = []
+ | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
+ -- now the input array is nonempty
+ | all (== 0) strides = 0 <$ sh
+ -- now there is at least one non-replicated dimension
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec') _ upindex _ _ -> unsafePerformIO $ do
+ let ndims' = length sh'
+ outvR <- VSM.unsafeNew (length sh')
+ VSM.unsafeWith outvR $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in fextrem poutv (fromIntegral ndims') psh pstrides (ptrconv pv')
+ upindex . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outvR
+
+{-# NOINLINE vectorDotprodInnerOp #-}
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (SNat n -> Array n a -> Array n a -> Array n a) -- ^ elementwise multiplication
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+ arr1@(Array sh1 strides1 offset1 vec1)
+ arr2@(Array sh2 strides2 offset2 vec2)
+ | null sh1 || null sh2 = error "unreachable"
+ | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | last sh1 <= 0 = arrayFromConstant (init sh1) 0
+ | any (<= 0) (init sh1) = Array (init sh1) (0 <$ init strides1) 0 VS.empty
+ -- now the input arrays are nonempty
+ | last sh1 == 1 =
+ fmul sn (Array (init sh1) (init strides1) offset1 vec1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ | last strides1 == 0 =
+ fmul sn
+ (Array (init sh1) (init strides1) offset1 vec1)
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+ | last strides2 == 0 =
+ fmul sn
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ -- now there is useful dotprod work along the inner dimension
+ | otherwise = unsafePerformIO $ do
+ let inrank = fromSNat' sn + 1
+ outv <- VSM.unsafeNew (product (init sh1))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+ pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
+ pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
+ arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+
+mulWithInt :: Num a => a -> Int -> a
+mulWithInt a i = a * fromIntegral i
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss_str = varE (aboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM intTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_ibinary_" ++ atCName arithtype
+ c_ss_str = varE (aiboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss_str = varE (afboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let scaleVar = case arithop of
+ RO_SUM -> varE 'mulWithInt
+ RO_PRODUCT -> varE '(^)
+ let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
+ namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
+ c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ sequence [SigD name1 <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
+ return $ FunD name1 [Clause [] (NormalB body) []]
+ ,SigD namefull <$>
+ [t| forall n. SNat n -> Array n $ttyp -> $ttyp |]
+ ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
+ return $ FunD namefull [Clause [] (NormalB body) []]
+ ])
+
+$(fmap concat . forM typesList $ \arithtype ->
+ fmap concat . forM ["min", "max"] $ \fname -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+ sequence [SigD name <$>
+ [t| forall n. Array n $ttyp -> [Int] |]
+ ,do body <- [| vectorExtremumOp id $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+ mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
+foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
+
+statisticsEnable :: Bool -> IO ()
+statisticsEnable b = c_stats_enable (if b then 1 else 0)
+
+-- | Consumes the log: one particular event will only ever be printed once,
+-- even if statisticsPrintAll is called multiple times.
+statisticsPrintAll :: IO ()
+statisticsPrintAll = do
+ hFlush stdout -- lower the chance of overlapping output
+ c_stats_print_all
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (SNat n -> Array n i -> Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr f32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr f64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -> (SNat n -> Array n i -> Array n i -> Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> Array (n + 1) i -> Array n i)
+intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> Int -> i) -- ^ scale op
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -> (SNat n -> Array n i -> i)
+intWidBranchRedFull fsc fred32 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> (Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+ | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+ | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
+ | otherwise = error "Unsupported Int width"
+
+class NumElt a where
+ numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltSub :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltMul :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltNeg :: SNat n -> Array n a -> Array n a
+ numEltAbs :: SNat n -> Array n a -> Array n a
+ numEltSignum :: SNat n -> Array n a -> Array n a
+ numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltSumFull :: SNat n -> Array n a -> a
+ numEltProductFull :: SNat n -> Array n a -> a
+ numEltMinIndex :: SNat n -> Array n a -> [Int]
+ numEltMaxIndex :: SNat n -> Array n a -> [Int]
+ numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+
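A hypothetical usage sketch of this class (not part of the commit, and assuming the extensions enabled at the top of this module): the rank is passed explicitly as an SNat, so element-wise addition of two rank-2 Double arrays looks like:

addMatrices :: Array 2 Double -> Array 2 Double -> Array 2 Double
addMatrices = numEltAdd (SNat :: SNat 2)
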
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
+ numEltSumFull = sumFullVectorInt32
+ numEltProductFull = productFullVectorInt32
+ numEltMinIndex _ = minindexVectorInt32
+ numEltMaxIndex _ = maxindexVectorInt32
+ numEltDotprodInner = dotprodinnerVectorInt32
+
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
+ numEltSumFull = sumFullVectorInt64
+ numEltProductFull = productFullVectorInt64
+ numEltMinIndex _ = minindexVectorInt64
+ numEltMaxIndex _ = maxindexVectorInt64
+ numEltDotprodInner = dotprodinnerVectorInt64
+
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
+ numEltSumFull = sumFullVectorFloat
+ numEltProductFull = productFullVectorFloat
+ numEltMinIndex _ = minindexVectorFloat
+ numEltMaxIndex _ = maxindexVectorFloat
+ numEltDotprodInner = dotprodinnerVectorFloat
+
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
+ numEltSumFull = sumFullVectorDouble
+ numEltProductFull = productFullVectorDouble
+ numEltMinIndex _ = minindexVectorDouble
+ numEltMaxIndex _ = maxindexVectorDouble
+ numEltDotprodInner = dotprodinnerVectorDouble
+
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+class NumElt a => IntElt a where
+ intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
+ intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance IntElt Int32 where
+ intEltQuot = quotVectorInt32
+ intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+ intEltQuot = quotVectorInt64
+ intEltRem = remVectorInt64
+
+instance IntElt Int where
+ intEltQuot = intWidBranch2 @Int quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @Int rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+ intEltQuot = intWidBranch2 @CInt quot
+ (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ intEltRem = intWidBranch2 @CInt rem
+ (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+ (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+class NumElt a => FloatElt a where
+ floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltRecip :: SNat n -> Array n a -> Array n a
+ floatEltExp :: SNat n -> Array n a -> Array n a
+ floatEltLog :: SNat n -> Array n a -> Array n a
+ floatEltSqrt :: SNat n -> Array n a -> Array n a
+ floatEltSin :: SNat n -> Array n a -> Array n a
+ floatEltCos :: SNat n -> Array n a -> Array n a
+ floatEltTan :: SNat n -> Array n a -> Array n a
+ floatEltAsin :: SNat n -> Array n a -> Array n a
+ floatEltAcos :: SNat n -> Array n a -> Array n a
+ floatEltAtan :: SNat n -> Array n a -> Array n a
+ floatEltSinh :: SNat n -> Array n a -> Array n a
+ floatEltCosh :: SNat n -> Array n a -> Array n a
+ floatEltTanh :: SNat n -> Array n a -> Array n a
+ floatEltAsinh :: SNat n -> Array n a -> Array n a
+ floatEltAcosh :: SNat n -> Array n a -> Array n a
+ floatEltAtanh :: SNat n -> Array n a -> Array n a
+ floatEltLog1p :: SNat n -> Array n a -> Array n a
+ floatEltExpm1 :: SNat n -> Array n a -> Array n a
+ floatEltLog1pexp :: SNat n -> Array n a -> Array n a
+ floatEltLog1mexp :: SNat n -> Array n a -> Array n a
+ floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltPow = powVectorFloat
+ floatEltLogbase = logbaseVectorFloat
+ floatEltRecip = recipVectorFloat
+ floatEltExp = expVectorFloat
+ floatEltLog = logVectorFloat
+ floatEltSqrt = sqrtVectorFloat
+ floatEltSin = sinVectorFloat
+ floatEltCos = cosVectorFloat
+ floatEltTan = tanVectorFloat
+ floatEltAsin = asinVectorFloat
+ floatEltAcos = acosVectorFloat
+ floatEltAtan = atanVectorFloat
+ floatEltSinh = sinhVectorFloat
+ floatEltCosh = coshVectorFloat
+ floatEltTanh = tanhVectorFloat
+ floatEltAsinh = asinhVectorFloat
+ floatEltAcosh = acoshVectorFloat
+ floatEltAtanh = atanhVectorFloat
+ floatEltLog1p = log1pVectorFloat
+ floatEltExpm1 = expm1VectorFloat
+ floatEltLog1pexp = log1pexpVectorFloat
+ floatEltLog1mexp = log1mexpVectorFloat
+ floatEltAtan2 = atan2VectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltPow = powVectorDouble
+ floatEltLogbase = logbaseVectorDouble
+ floatEltRecip = recipVectorDouble
+ floatEltExp = expVectorDouble
+ floatEltLog = logVectorDouble
+ floatEltSqrt = sqrtVectorDouble
+ floatEltSin = sinVectorDouble
+ floatEltCos = cosVectorDouble
+ floatEltTan = tanVectorDouble
+ floatEltAsin = asinVectorDouble
+ floatEltAcos = acosVectorDouble
+ floatEltAtan = atanVectorDouble
+ floatEltSinh = sinhVectorDouble
+ floatEltCosh = coshVectorDouble
+ floatEltTanh = tanhVectorDouble
+ floatEltAsinh = asinhVectorDouble
+ floatEltAcosh = acoshVectorDouble
+ floatEltAtanh = atanhVectorDouble
+ floatEltLog1p = log1pVectorDouble
+ floatEltExpm1 = expm1VectorDouble
+ floatEltLog1pexp = log1pexpVectorDouble
+ floatEltLog1mexp = log1mexpVectorDouble
+ floatEltAtan2 = atan2VectorDouble
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
index 78d5365..dad65f9 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
@@ -1,13 +1,13 @@
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Mixed.Internal.Arith.Foreign where
+module Data.Array.Strided.Arith.Internal.Foreign where
import Data.Int
import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
-import Data.Array.Mixed.Internal.Arith.Lists
+import Data.Array.Strided.Arith.Internal.Lists
$(do
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 370b708..910a77c 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -1,12 +1,12 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Mixed.Internal.Arith.Lists where
+module Data.Array.Strided.Arith.Internal.Lists where
import Data.Char
import Data.Int
import Language.Haskell.TH
-import Data.Array.Mixed.Internal.Arith.Lists.TH
+import Data.Array.Strided.Arith.Internal.Lists.TH
data ArithType = ArithType
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs
index a156e29..b8f6a3d 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE TemplateHaskellQuotes #-}
-module Data.Array.Mixed.Internal.Arith.Lists.TH where
+module Data.Array.Strided.Arith.Internal.Lists.TH where
import Control.Monad
import Control.Monad.IO.Class
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs
new file mode 100644
index 0000000..a772aaf
--- /dev/null
+++ b/ops/Data/Array/Strided/Array.hs
@@ -0,0 +1,42 @@
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+module Data.Array.Strided.Array where
+
+import qualified Data.List.NonEmpty as NE
+import Data.Proxy
+import qualified Data.Vector.Storable as VS
+import Foreign.Storable
+import GHC.TypeLits
+
+
+data Array (n :: Nat) a = Array
+ { arrShape :: ![Int]
+ , arrStrides :: ![Int]
+ , arrOffset :: !Int
+ , arrValues :: !(VS.Vector a)
+ }
+
+-- | Takes a vector in normalised order (inner dimension, i.e. last in the
+-- list, iterates fastest).
+arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a
+arrayFromVector sh vec
+ | VS.length vec == shsize
+ , length sh == fromIntegral (natVal (Proxy @n))
+ = Array sh strides 0 vec
+ | otherwise = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec)
+ where
+ shsize = product sh
+ strides = NE.tail (NE.scanr (*) 1 sh)
+
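For example (not part of the commit), the stride computation above produces the usual row-major strides:

-- sh = [2,3,4]
-- NE.scanr (*) 1 sh   ==  24 :| [12,4,1]   (the head, 24, is the total size)
-- strides             ==  [12,4,1]         (inner dimension iterates fastest)
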
+arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a
+arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x)
+
+arrayRevDims :: [Bool] -> Array n a -> Array n a
+arrayRevDims bs (Array sh strides offset vec)
+ | length bs == length sh =
+ Array sh
+ (zipWith (\b s -> if b then -s else s) bs strides)
+ (offset + sum (zipWith3 (\b n s -> if b then (n - 1) * s else 0) bs sh strides))
+ vec
+ | otherwise = error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank " ++ show (length sh)
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 19a61ab..ecd3ba7 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -30,9 +30,6 @@ flag nonportable-simd
library
exposed-modules:
Data.Array.Mixed.Internal.Arith
- Data.Array.Mixed.Internal.Arith.Foreign
- Data.Array.Mixed.Internal.Arith.Lists
- Data.Array.Mixed.Internal.Arith.Lists.TH
Data.Array.Mixed.Lemmas
Data.Array.Mixed.Permutation
Data.Array.Mixed.Shape
@@ -52,6 +49,8 @@ library
Data.Array.Nested.Trace.TH
build-depends:
+ strided-array-ops,
+
base >=4.18 && <4.21,
deepseq,
ghc-typelits-knownnat,
@@ -60,6 +59,27 @@ library
template-haskell,
vector
hs-source-dirs: src
+
+ default-language: Haskell2010
+ ghc-options: -Wall
+ other-extensions: TemplateHaskell
+
+library strided-array-ops
+ exposed-modules:
+ Data.Array.Strided
+ Data.Array.Strided.Array
+ Data.Array.Strided.Arith
+ Data.Array.Strided.Arith.Internal
+ Data.Array.Strided.Arith.Internal.Foreign
+ Data.Array.Strided.Arith.Internal.Lists
+ Data.Array.Strided.Arith.Internal.Lists.TH
+ build-depends:
+ base,
+ ghc-typelits-knownnat,
+ ghc-typelits-natnormalise,
+ template-haskell,
+ vector
+ hs-source-dirs: ops
c-sources: cbits/arith.c
cc-options: -O3 -Wall -Wextra -std=c99
@@ -112,6 +132,7 @@ benchmark bench
main-is: Main.hs
build-depends:
ox-arrays,
+ strided-array-ops,
base,
hmatrix,
orthotope,
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 27ebb64..f7a76bc 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -1,929 +1,23 @@
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TemplateHaskell #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Internal.Arith where
-import Control.Monad (forM, guard)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
-import Data.Bifunctor (second)
-import Data.Bits
-import Data.Int
-import Data.List (sort)
-import Data.Vector.Storable qualified as VS
-import Data.Vector.Storable.Mutable qualified as VSM
-import Foreign.C.Types
-import Foreign.Marshal.Alloc (alloca)
-import Foreign.Ptr
-import Foreign.Storable (Storable(sizeOf), peek, poke)
-import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
-import Language.Haskell.TH
-import System.IO (hFlush, stdout)
-import System.IO.Unsafe
-import Data.Array.Mixed.Internal.Arith.Foreign
-import Data.Array.Mixed.Internal.Arith.Lists
-import Data.Array.Mixed.Types (fromSNat')
+import Data.Array.Strided qualified as AS
--- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+fromO :: RS.Array n a -> AS.Array n a
+fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec
+toO :: AS.Array n a -> RS.Array n a
+toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
--- TODO: test all the cases of this thing with various input strides
-liftVEltwise1 :: (Storable a, Storable b)
- => SNat n
- -> (VS.Vector a -> VS.Vector b)
- -> RS.Array n a -> RS.Array n b
-liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- let vec' = f (VS.slice blockOff blockSz vec)
- in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
- | otherwise = RS.fromVector sh (f (RS.toVector arr))
+liftO1 :: (AS.Array n a -> AS.Array n' b)
+ -> RS.Array n a -> RS.Array n' b
+liftO1 f = toO . f . fromO
--- TODO: test all the cases of this thing with various input strides
-{-# NOINLINE liftOpEltwise1 #-}
-liftOpEltwise1 :: (Storable a, Storable b)
- => SNat n
- -> (Ptr a -> Ptr a')
- -> (Ptr b -> Ptr b')
- -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ())
- -> RS.Array n a -> RS.Array n b
-liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
- -- TODO: less code duplication between these two branches
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- if blockSz == 0
- then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty))
- else unsafePerformIO $ do
- outv <- VSM.unsafeNew blockSz
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
- VS.unsafeWith (VS.singleton 1) $ \pstrides ->
- VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
- cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
- RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
- | otherwise = unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
- cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
--- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (a -> a -> a)
- -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
-liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
- arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
- arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
- | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
- (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
- let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
- in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
-
- (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
- let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
- in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec))
-
- (Just (_, 1), Nothing) -> -- scalar * array
- wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
-
- (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
- let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
- in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec))
-
- (Nothing, Just (_, 1)) -> -- array * scalar
- wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
-
- (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
- | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the strides check
- , strides1 == strides2
- -> -- dense * dense but the strides match
- let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
- arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
- in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec))
-
- (_, _) -> -- fallback case
- wrapBinaryVV sn ptrconv f_vv arr1 arr2
-
--- | Given shape vector, offset and stride vector, check whether this virtual
--- vector uses a dense subarray of its backing array. If so, the first index
--- and the number of elements in this subarray is returned.
--- This excludes any offset.
-stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
-stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
-stridesDense sh offsetNeg stridesNeg =
- -- First reverse all dimensions with negative stride, so that the first used
- -- value is at 'offset' and the rest is >= offset.
- let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
- in -- sort dimensions on their stride, ascending, dropping any zero strides
- case filter ((/= 0) . fst) (sort (zip strides sh)) of
- [] -> Just (offset, 1)
- (1, n) : pairs -> (offset,) <$> checkCover n pairs
- _ -> Nothing -- if the smallest stride is not 1, it will never be dense
- where
- -- Given size of currently densely covered region at beginning of the
- -- array and the remaining (stride, size) pairs with all strides >=1,
- -- return whether this all together covers a dense prefix of the array. If
- -- it does, return the number of elements in this prefix.
- checkCover :: Int -> [(Int, Int)] -> Maybe Int
- checkCover block [] = Just block
- checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs
-
- -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
- flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
- flipReverseds [] off [] = (off, [])
- flipReverseds (n : sh') off (s : str')
- | s >= 0 = second (s :) (flipReverseds sh' off str')
- | otherwise =
- let off' = off + (n - 1) * s
- in second ((-s) :) (flipReverseds sh' off' str')
- flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
-
-{-# NOINLINE wrapBinarySV #-}
-wrapBinarySV :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
- -> a -> RS.Array n a
- -> RS.Array n a
-wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) =
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
- cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
-wrapBinaryVS :: Storable a
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
- -> RS.Array n a -> a
- -> RS.Array n a
-wrapBinaryVS sn valconv ptrconv cf_strided arr y =
- wrapBinarySV sn valconv ptrconv
- (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
-
--- | This function assumes that the two shapes are equal.
-{-# NOINLINE wrapBinaryVV #-}
-wrapBinaryVV :: Storable a
- => SNat n
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
- -> RS.Array n a -> RS.Array n a
- -> RS.Array n a
-wrapBinaryVV sn@SNat ptrconv cf_strided
- (RS.A (RG.A sh (OI.T strides1 offset1 vec1)))
- (RS.A (RG.A _ (OI.T strides2 offset2 vec2))) =
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (product sh)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
- VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
- VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 ->
- VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 ->
- cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2)
- RS.fromVector sh <$> VS.unsafeFreeze outv
-
-{-# NOINLINE vectorOp1 #-}
-vectorOp1 :: forall a b. Storable a
- => (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr b -> IO ())
- -> VS.Vector a -> VS.Vector a
-vectorOp1 ptrconv f v = unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length v)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith v $ \pv ->
- f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
- VS.unsafeFreeze outv
-
--- | If two vectors are given, assumes that they have the same length.
-{-# NOINLINE vectorOp2 #-}
-vectorOp2 :: forall a b. Storable a
- => (a -> b)
- -> (Ptr a -> Ptr b)
- -> (a -> a -> a)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
- -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
- -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
- -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
-vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
- (Left x) (Left y) -> VS.singleton (fss x y)
-
- (Left x) (Right vy) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vy)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vy $ \pvy ->
- fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
- VS.unsafeFreeze outv
-
- (Right vx) (Left y) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
- VS.unsafeFreeze outv
-
- (Right vx) (Right vy)
- | VS.length vx == VS.length vy ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- VS.unsafeWith vy $ \pvy ->
- fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
- VS.unsafeFreeze outv
- | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
-
--- TODO: test handling of negative strides
--- | Reduce along the inner dimension
-{-# NOINLINE vectorRedInnerOp #-}
-vectorRedInnerOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> RS.Array (n + 1) a -> RS.Array n a
-vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = error "unreachable"
- | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0])
- | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty))
- -- now the input array is nonempty
- | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
- | last strides == 0 =
- liftVEltwise1 sn
- (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
- (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
- -- now there is useful work along the inner dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The reduction
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- -- replace replicated dimensions with ones
- shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
- ndimsF = length shF -- > 0, otherwise `last strides == 0`
-
- -- reversed dimensions: dimensions with negative stride. Reversal is
- -- irrelevant for a reduction, and indeed the kernel has a
- -- precondition that there are no such dimensions.
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
- -- The *R values give an array with strides all > 0, hence the
- -- left-most element is at offsetR.
- in unsafePerformIO $ do
- outvR <- VSM.unsafeNew (product (init shF))
- VSM.unsafeWith outvR $ \poutvR ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
- TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
- RS.stretch (init sh) -- replicate to original shape
- . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
- . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions
- . RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array
- <$> VS.unsafeFreeze outvR
-
--- TODO: test handling of negative strides
--- | Reduce full array
-{-# NOINLINE vectorRedFullOp #-}
-vectorRedFullOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> Int -> a)
- -> (b -> a)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
- -> RS.Array n a -> a
-vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = vec VS.! offset -- 0D array has one element
- | any (<= 0) sh = 0
- -- now the input array is nonempty
- | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset
- -- now there is at least one non-replicated dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The reduction
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- ndimsF = length shF -- > 0, otherwise `all (== 0) strides`
- -- we should scale up the output this many times to account for the replicated dimensions
- multiplier = product [n | (n, True) <- zip sh replDims]
-
- -- reversed dimensions: dimensions with negative stride. Reversal is
- -- irrelevant for a reduction, and indeed the kernel has a
- -- precondition that there are no such dimensions.
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
- -- The *R values give an array with strides all > 0, hence the
- -- left-most element is at offsetR.
- in unsafePerformIO $ do
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- (`scaleval` multiplier) . valbackconv
- <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
-
--- TODO: test this function
--- | Find extremum (minindex ("argmin") or maxindex) in full array
-{-# NOINLINE vectorExtremumOp #-}
-vectorExtremumOp :: forall a b n. Storable a
- => (Ptr a -> Ptr b)
- -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
- -> RS.Array n a -> [Int] -- result length: n
-vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = []
- | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
- -- now the input array is nonempty
- | all (== 0) strides = 0 <$ sh
- -- now there is at least one non-replicated dimension
- | otherwise =
- let -- replicated dimensions: dimensions with zero stride. The extremum
- -- kernel need not concern itself with those (and in fact has a
- -- precondition that there are no such dimensions in its input).
- replDims = map (== 0) strides
- -- filter out replicated dimensions
- (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
- ndimsF = length shF -- > 0, because not all strides were <=0
-
- -- un-reverse reversed dimensions
- revDims = map (< 0) stridesF
- stridesR = map abs stridesF
- offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
-
- -- function to insert zeros in replicated-out dimensions
- insertZeros :: [Bool] -> [Int] -> [Int]
- insertZeros [] idx = idx
- insertZeros (True : repls) idx = 0 : insertZeros repls idx
- insertZeros (False : repls) (i : idx) = i : insertZeros repls idx
- insertZeros (_:_) [] = error "unreachable"
- in unsafePerformIO $ do
- outvR <- VSM.unsafeNew (length shF)
- VSM.unsafeWith outvR $ \poutvR ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
- VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
- insertZeros replDims
- . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF -- re-reverse the reversed dimensions
- . map (fromIntegral @Int64 @Int)
- . VS.toList
- <$> VS.unsafeFreeze outvR
-
-vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
- -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
-vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
- arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
- arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | null sh1 || null sh2 = error "unreachable"
- | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0])
- | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty))
- -- now the input arrays are nonempty
- | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2)
- | last strides1 == 0 =
- fmul sn
- (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1)))
- (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
- | last strides2 == 0 =
- fmul sn
- (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
- (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2)))
- -- now there is useful dotprod work along the inner dimension
- | otherwise = unsafePerformIO $ do
- let inrank = fromSNat' sn + 1
- outv <- VSM.unsafeNew (product (init sh1))
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
- pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
- pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
- RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
-
-{-# NOINLINE dotScalarVector #-}
-dotScalarVector :: forall a b. (Num a, Storable a)
- => Int -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> a -> VS.Vector a -> a
-dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
- alloca @a $ \pout -> do
- alloca @Int64 $ \pshape -> do
- poke pshape (fromIntegral @Int @Int64 len)
- alloca @Int64 $ \pstride -> do
- poke pstride 1
- VS.unsafeWith vec $ \pvec ->
- fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)
- res <- peek pout
- return (scalar * res)
-
-{-# NOINLINE dotVectorVector #-}
-dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
- -> VS.Vector a -> VS.Vector a -> a
-dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2)
-
-{-# NOINLINE dotVectorVectorStrided #-}
-dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
- -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel
- -> Int -> Int -> VS.Vector a
- -> Int -> Int -> VS.Vector a
- -> a
-dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- valbackconv <$> fdot (fromIntegral @Int @Int64 len)
- (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1)
- (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2)
-
-flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
- -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
-flipOp f n out v s = f n out s v
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_binary_" ++ atCName arithtype
- c_ss_str = varE (aboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM intTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_ibinary_" ++ atCName arithtype
- c_ss_str = varE (aiboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_fbinary_" ++ atCName arithtype
- c_ss_str = varE (afboNumOp arithop)
- c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-mulWithInt :: Num a => a -> Int -> a
-mulWithInt a i = a * fromIntegral i
-
-scaleFromSVStrided :: (Int64 -> Ptr Int64 -> Ptr a -> a -> Ptr Int64 -> Ptr a -> IO ())
- -> Int64 -> Ptr a -> a -> Ptr a -> IO ()
-scaleFromSVStrided fsv n out x ys =
- VS.unsafeWith (VS.singleton n) $ \psh ->
- VS.unsafeWith (VS.singleton 1) $ \pstrides ->
- fsv 1 psh out x pstrides ys
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let scaleVar = case arithop of
- RO_SUM -> varE 'mulWithInt
- RO_PRODUCT -> varE '(^)
- let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
- namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
- c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
- c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
- c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- sequence [SigD name1 <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorRedInnerOp sn id id (scaleFromSVStrided $c_scale_op) $c_op1 |]
- return $ FunD name1 [Clause [] (NormalB body) []]
- ,SigD namefull <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |]
- ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
- return $ FunD namefull [Clause [] (NormalB body) []]
- ])
-
-$(fmap concat . forM typesList $ \arithtype ->
- fmap concat . forM ["min", "max"] $ \fname -> do
- let ttyp = conT (atType arithtype)
- name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
- sequence [SigD name <$>
- [t| forall n. RS.Array n $ttyp -> [Int] |]
- ,do body <- [| vectorExtremumOp id $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
- mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
- c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
-foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
-
-statisticsEnable :: Bool -> IO ()
-statisticsEnable b = c_stats_enable (if b then 1 else 0)
-
--- | Consumes the log: one particular event will only ever be printed once,
--- even if statisticsPrintAll is called multiple times.
-statisticsPrintAll :: IO ()
-statisticsPrintAll = do
- hFlush stdout -- lower the chance of overlapping output
- c_stats_print_all
-
--- This branch is ostensibly a runtime branch, but will (hopefully) be
--- constant-folded away by GHC.
-intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
- => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
- -> (SNat n -> RS.Array n i -> RS.Array n i)
-intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32
- | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64
- | otherwise = error "Unsupported Int width"
-
-intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
- => (i -> i -> i) -- ss
- -- int32
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- vv
- -- int64
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
- -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
-intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
- | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
- => (i -> Int -> i) -- ^ scale op
- -- int32
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel
- -> (SNat n -> RS.Array n i -> i)
-intWidBranchRedFull fsc fred32 fred64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
- | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ extremum kernel
- -- int64
- -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ extremum kernel
- -> (RS.Array n i -> [Int])
-intWidBranchExtr fextr32 fextr64
- | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
- | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
- | otherwise = error "Unsupported Int width"
-
-intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ dotprod kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
- | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
- | otherwise = error "Unsupported Int width"
-
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltSumFull :: SNat n -> RS.Array n a -> a
- numEltProductFull :: SNat n -> RS.Array n a -> a
- numEltMinIndex :: SNat n -> RS.Array n a -> [Int]
- numEltMaxIndex :: SNat n -> RS.Array n a -> [Int]
- numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
-
-instance NumElt Int32 where
- numEltAdd = addVectorInt32
- numEltSub = subVectorInt32
- numEltMul = mulVectorInt32
- numEltNeg = negVectorInt32
- numEltAbs = absVectorInt32
- numEltSignum = signumVectorInt32
- numEltSum1Inner = sum1VectorInt32
- numEltProduct1Inner = product1VectorInt32
- numEltSumFull = sumFullVectorInt32
- numEltProductFull = productFullVectorInt32
- numEltMinIndex _ = minindexVectorInt32
- numEltMaxIndex _ = maxindexVectorInt32
- numEltDotprodInner = dotprodinnerVectorInt32
-
-instance NumElt Int64 where
- numEltAdd = addVectorInt64
- numEltSub = subVectorInt64
- numEltMul = mulVectorInt64
- numEltNeg = negVectorInt64
- numEltAbs = absVectorInt64
- numEltSignum = signumVectorInt64
- numEltSum1Inner = sum1VectorInt64
- numEltProduct1Inner = product1VectorInt64
- numEltSumFull = sumFullVectorInt64
- numEltProductFull = productFullVectorInt64
- numEltMinIndex _ = minindexVectorInt64
- numEltMaxIndex _ = maxindexVectorInt64
- numEltDotprodInner = dotprodinnerVectorInt64
-
-instance NumElt Float where
- numEltAdd = addVectorFloat
- numEltSub = subVectorFloat
- numEltMul = mulVectorFloat
- numEltNeg = negVectorFloat
- numEltAbs = absVectorFloat
- numEltSignum = signumVectorFloat
- numEltSum1Inner = sum1VectorFloat
- numEltProduct1Inner = product1VectorFloat
- numEltSumFull = sumFullVectorFloat
- numEltProductFull = productFullVectorFloat
- numEltMinIndex _ = minindexVectorFloat
- numEltMaxIndex _ = maxindexVectorFloat
- numEltDotprodInner = dotprodinnerVectorFloat
-
-instance NumElt Double where
- numEltAdd = addVectorDouble
- numEltSub = subVectorDouble
- numEltMul = mulVectorDouble
- numEltNeg = negVectorDouble
- numEltAbs = absVectorDouble
- numEltSignum = signumVectorDouble
- numEltSum1Inner = sum1VectorDouble
- numEltProduct1Inner = product1VectorDouble
- numEltSumFull = sumFullVectorDouble
- numEltProductFull = productFullVectorDouble
- numEltMinIndex _ = minindexVectorDouble
- numEltMaxIndex _ = maxindexVectorDouble
- numEltDotprodInner = dotprodinnerVectorDouble
-
-instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+)
- (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
- (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @Int (-)
- (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
- (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @Int (*)
- (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
- (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed1 @Int
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))
- numEltProduct1Inner = intWidBranchRed1 @Int
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))
- numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
- numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
- numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
- numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
- numEltDotprodInner = intWidBranchDotprod @Int (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
-
-instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+)
- (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
- (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @CInt (-)
- (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
- (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @CInt (*)
- (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
- (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed1 @CInt
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))
- numEltProduct1Inner = intWidBranchRed1 @CInt
- (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT))
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))
- numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
- numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
- numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
- numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
- numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
- (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
-
-class NumElt a => IntElt a where
- intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
-
-instance IntElt Int32 where
- intEltQuot = quotVectorInt32
- intEltRem = remVectorInt32
-
-instance IntElt Int64 where
- intEltQuot = quotVectorInt64
- intEltRem = remVectorInt64
-
-instance IntElt Int where
- intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
- intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
-
-instance IntElt CInt where
- intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
- intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
-
-class NumElt a => FloatElt a where
- floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
-
-instance FloatElt Float where
- floatEltDiv = divVectorFloat
- floatEltPow = powVectorFloat
- floatEltLogbase = logbaseVectorFloat
- floatEltRecip = recipVectorFloat
- floatEltExp = expVectorFloat
- floatEltLog = logVectorFloat
- floatEltSqrt = sqrtVectorFloat
- floatEltSin = sinVectorFloat
- floatEltCos = cosVectorFloat
- floatEltTan = tanVectorFloat
- floatEltAsin = asinVectorFloat
- floatEltAcos = acosVectorFloat
- floatEltAtan = atanVectorFloat
- floatEltSinh = sinhVectorFloat
- floatEltCosh = coshVectorFloat
- floatEltTanh = tanhVectorFloat
- floatEltAsinh = asinhVectorFloat
- floatEltAcosh = acoshVectorFloat
- floatEltAtanh = atanhVectorFloat
- floatEltLog1p = log1pVectorFloat
- floatEltExpm1 = expm1VectorFloat
- floatEltLog1pexp = log1pexpVectorFloat
- floatEltLog1mexp = log1mexpVectorFloat
- floatEltAtan2 = atan2VectorFloat
-
-instance FloatElt Double where
- floatEltDiv = divVectorDouble
- floatEltPow = powVectorDouble
- floatEltLogbase = logbaseVectorDouble
- floatEltRecip = recipVectorDouble
- floatEltExp = expVectorDouble
- floatEltLog = logVectorDouble
- floatEltSqrt = sqrtVectorDouble
- floatEltSin = sinVectorDouble
- floatEltCos = cosVectorDouble
- floatEltTan = tanVectorDouble
- floatEltAsin = asinVectorDouble
- floatEltAcos = acosVectorDouble
- floatEltAtan = atanVectorDouble
- floatEltSinh = sinhVectorDouble
- floatEltCosh = coshVectorDouble
- floatEltTanh = tanhVectorDouble
- floatEltAsinh = asinhVectorDouble
- floatEltAcosh = acoshVectorDouble
- floatEltAtanh = atanhVectorDouble
- floatEltLog1p = log1pVectorDouble
- floatEltExpm1 = expm1VectorDouble
- floatEltLog1pexp = log1pexpVectorDouble
- floatEltLog1mexp = log1mexpVectorDouble
- floatEltAtan2 = atan2VectorDouble
+liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
+ -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
+liftO2 f x y = toO (f (fromO x) (fromO y))
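Illustrative aside (hypothetical sketch, not part of the patch): after this change the module only adapts between orthotope's internal representation and the strided-array-ops representation. Assuming the relocated NumElt class keeps its method names and per-method SNat argument, as the call sites later in this commit suggest, liftO1 and liftO2 give orthotope-typed wrappers around the new kernels:

module ArithShimSketch where   -- hypothetical example module

import qualified Data.Array.Internal.RankedS as RS
import Data.Array.Mixed.Internal.Arith (liftO1, liftO2)
import Data.Array.Strided.Arith
import GHC.TypeLits (SNat)

-- Elementwise addition on orthotope arrays, routed through the new library:
addRS :: NumElt a => SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
addRS sn = liftO2 (numEltAdd sn)

-- The unary case is analogous:
negRS :: NumElt a => SNat n -> RS.Array n a -> RS.Array n a
negRS sn = liftO1 (numEltNeg sn)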
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 71bdc1f..204c1d8 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -34,6 +34,7 @@ import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+import Data.Array.Strided.Arith
type XArray :: [Maybe Nat] -> Type -> Type
@@ -240,7 +241,7 @@ transpose2 ssh1 ssh2 (XArray arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
S.unScalar $
- numEltSum1Inner (SNat @0) $
+ liftO1 (numEltSum1Inner (SNat @0)) $
S.fromVector [product (S.shapeL arr)] $
S.toVector arr
@@ -256,7 +257,7 @@ sumInner ssh ssh' arr
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
, let sn = listxRank (let StaticShX l = ssh in l)
- = XArray (numEltSum1Inner sn arr')
+ = XArray (liftO1 (numEltSum1Inner sn) arr')
in go $
transpose2 ssh'F ssh $
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index db13da4..9869cba 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -103,7 +103,6 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
-import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
@@ -112,6 +111,7 @@ import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped
+import Data.Array.Strided.Arith
import Foreign.Storable
import GHC.TypeLits
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 80d581e..eb452dd 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Lemmas
+import Data.Array.Strided.Arith
-- TODO:
-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
@@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_
| otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
- (+) = mliftNumElt2 numEltAdd
- (-) = mliftNumElt2 numEltSub
- (*) = mliftNumElt2 numEltMul
- negate = mliftNumElt1 numEltNeg
- abs = mliftNumElt1 numEltAbs
- signum = mliftNumElt1 numEltSignum
+ (+) = mliftNumElt2 (liftO2 . numEltAdd)
+ (-) = mliftNumElt2 (liftO2 . numEltSub)
+ (*) = mliftNumElt2 (liftO2 . numEltMul)
+ negate = mliftNumElt1 (liftO1 . numEltNeg)
+ abs = mliftNumElt1 (liftO1 . numEltAbs)
+ signum = mliftNumElt1 (liftO1 . numEltSignum)
-- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftNumElt1 floatEltRecip
- (/) = mliftNumElt2 floatEltDiv
+ recip = mliftNumElt1 (liftO1 . floatEltRecip)
+ (/) = mliftNumElt2 (liftO2 . floatEltDiv)
instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
- exp = mliftNumElt1 floatEltExp
- log = mliftNumElt1 floatEltLog
- sqrt = mliftNumElt1 floatEltSqrt
-
- (**) = mliftNumElt2 floatEltPow
- logBase = mliftNumElt2 floatEltLogbase
-
- sin = mliftNumElt1 floatEltSin
- cos = mliftNumElt1 floatEltCos
- tan = mliftNumElt1 floatEltTan
- asin = mliftNumElt1 floatEltAsin
- acos = mliftNumElt1 floatEltAcos
- atan = mliftNumElt1 floatEltAtan
- sinh = mliftNumElt1 floatEltSinh
- cosh = mliftNumElt1 floatEltCosh
- tanh = mliftNumElt1 floatEltTanh
- asinh = mliftNumElt1 floatEltAsinh
- acosh = mliftNumElt1 floatEltAcosh
- atanh = mliftNumElt1 floatEltAtanh
- log1p = mliftNumElt1 floatEltLog1p
- expm1 = mliftNumElt1 floatEltExpm1
- log1pexp = mliftNumElt1 floatEltLog1pexp
- log1mexp = mliftNumElt1 floatEltLog1mexp
+ exp = mliftNumElt1 (liftO1 . floatEltExp)
+ log = mliftNumElt1 (liftO1 . floatEltLog)
+ sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
+
+ (**) = mliftNumElt2 (liftO2 . floatEltPow)
+ logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
+
+ sin = mliftNumElt1 (liftO1 . floatEltSin)
+ cos = mliftNumElt1 (liftO1 . floatEltCos)
+ tan = mliftNumElt1 (liftO1 . floatEltTan)
+ asin = mliftNumElt1 (liftO1 . floatEltAsin)
+ acos = mliftNumElt1 (liftO1 . floatEltAcos)
+ atan = mliftNumElt1 (liftO1 . floatEltAtan)
+ sinh = mliftNumElt1 (liftO1 . floatEltSinh)
+ cosh = mliftNumElt1 (liftO1 . floatEltCosh)
+ tanh = mliftNumElt1 (liftO1 . floatEltTanh)
+ asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
+ acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
+ atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
+ log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
+ expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
+ log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
+ log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-mquotArray = mliftNumElt2 intEltQuot
-mremArray = mliftNumElt2 intEltRem
+mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
+mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-matan2Array = mliftNumElt2 floatEltAtan2
+matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
@@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr)
+ ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
-- | Throws if the array is empty.
mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr)
+ ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
@@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
_ :$% _
| sh1 == sh2
, Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
- fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
| otherwise -> error "mdot1Inner: Unequal shapes"
ZSX -> error "unreachable"
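Editorial aside (assumed signatures, not part of the patch): the rewrite in this file is mechanical. Post-composing each NumElt/FloatElt method with liftO1 or liftO2 restores exactly the orthotope-typed shape that mliftNumElt1 and mliftNumElt2 already expected before the relocation:

--   numEltAdd          :: NumElt a => SNat n -> AS.Array n a -> AS.Array n a -> AS.Array n a
--   liftO2, specialised :: (AS.Array n a -> AS.Array n a -> AS.Array n a)
--                       -> RS.Array n a -> RS.Array n a -> RS.Array n a
--   liftO2 . numEltAdd :: NumElt a => SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
--
-- The unary methods compose with liftO1 in the same way.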
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 1c6b789..0a165bc 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -41,13 +41,13 @@ import GHC.TypeNats qualified as TN
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape
+import Data.Array.Strided.Arith
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 35628db..d7a8ece 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -41,7 +41,6 @@ import GHC.TypeLits
import Data.Array.Mixed.XArray (XArray)
import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
@@ -49,6 +48,7 @@ import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape
+import Data.Array.Strided.Arith
-- | A shape-typed array: the full shape of the array (the sizes of its