 bench/Main.hs                                     |   2
 cbits/arith.c                                     |  40
 ops/Data/Array/Strided.hs                         |   7
 ops/Data/Array/Strided/Arith.hs                   |   7
 ops/Data/Array/Strided/Arith/Internal.hs          | 866
 ops/Data/Array/Strided/Arith/Internal/Foreign.hs  |   4  (renamed from src/Data/Array/Mixed/Internal/Arith/Foreign.hs)
 ops/Data/Array/Strided/Arith/Internal/Lists.hs    |   4  (renamed from src/Data/Array/Mixed/Internal/Arith/Lists.hs)
 ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs |   2  (renamed from src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs)
 ops/Data/Array/Strided/Array.hs                   |  42
 ox-arrays.cabal                                   |  27
 src/Data/Array/Mixed/Internal/Arith.hs            | 928
 src/Data/Array/Mixed/XArray.hs                    |   5
 src/Data/Array/Nested.hs                          |   2
 src/Data/Array/Nested/Internal/Mixed.hs           |  71
 src/Data/Array/Nested/Internal/Ranked.hs          |   2
 src/Data/Array/Nested/Internal/Shaped.hs          |   2
16 files changed, 1026 insertions, 985 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 6e83270..3ab81a8 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -12,7 +12,7 @@ import Test.Tasty.Bench
 import Data.Array.Nested
 import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2)
 import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
-import qualified Data.Array.Mixed.Internal.Arith as Arith
+import qualified Data.Array.Strided.Arith.Internal as Arith
 
 enableMisc :: Bool
diff --git a/cbits/arith.c b/cbits/arith.c
index ca0af51..b574d54 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -24,6 +24,17 @@ typedef int32_t i32;
 typedef int64_t i64;
 
+// PRECONDITIONS
+//
+// All strided array operations in this file assume that none of the shape
+// components are zero -- that is, the input arrays are non-empty. This must
+// be arranged on the Haskell side.
+//
+// Furthermore, note that while the Haskell side has an offset into the backing
+// vector, the C side assumes that the offset is zero. Shift the pointer if
+// necessary.
+
+
 /*****************************************************************************
  *                          Performance statistics                           *
  *****************************************************************************/
@@ -370,6 +381,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 #define COMM_OP_STRIDED(name, op, typ) \
   static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+    if (rank == 0) { out[0] = x op y[0]; return; } \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
@@ -377,6 +389,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     }); \
   } \
   static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+    if (rank == 0) { out[0] = x[0] op y[0]; return; } \
     TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
@@ -387,6 +400,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 
 #define NONCOMM_OP_STRIDED(name, op, typ) \
   COMM_OP_STRIDED(name, op, typ) \
   static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+    if (rank == 0) { out[0] = x[0] op y; return; } \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
@@ -396,6 +410,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 
 #define PREFIX_BINOP_STRIDED(name, op, typ) \
   static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+    if (rank == 0) { out[0] = op(x, y[0]); return; } \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
@@ -403,6 +418,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     }); \
   } \
   static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+    if (rank == 0) { out[0] = op(x[0], y[0]); return; } \
     TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
@@ -410,6 +426,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     }); \
   } \
   static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+    if (rank == 0) { out[0] = op(x[0], y); return; } \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
@@ -424,6 +441,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     fprintf(stderr, " strides="); \
     print_shape(stderr, rank, strides); \
     fprintf(stderr, "\n"); */ \
+    if (rank == 0) { out[0] = op(arr[0]); return; } \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       for (i64 i = 0; i < shape[rank - 1]; i++) { \
         out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \
@@ -434,7 +452,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 
 // Used for reduction and dot product kernels below
 #define MANUAL_VECT_WID 8
 
-// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions
+// Used in REDUCE1_OP and REDUCEFULL_OP below
 #define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \
   do { \
     const i64 n = innerLen; const i64 s = innerStride; \
@@ -458,11 +476,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     } \
   } while (0)
 
-// preconditions:
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
-// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
 // Reduces along the innermost dimension.
 // 'out' will be filled densely in linearisation order.
 #define REDUCE1_OP(name, op, typ) \
@@ -472,12 +485,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     }); \
   }
 
-// preconditions
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
 #define REDUCEFULL_OP(name, op, typ) \
   typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    if (rank == 0) return arr[0]; \
     typ result = 0; \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
       REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \
@@ -485,13 +495,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     return result; \
   }
 
-// preconditions
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
 // Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
 #define EXTREMUM_OP(name, cmp, typ) \
   void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    if (rank == 0) return; /* output index vector has length 0 anyways */ \
     typ best = arr[0]; \
     memset(outidx, 0, rank * sizeof(i64)); \
     TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
@@ -527,11 +534,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     } \
   }
 
-// preconditions:
-// - all strides are >0
-// - shape is everywhere >0
-// - rank is >= 1
-// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
 // Reduces along the innermost dimension.
 // 'out' will be filled densely in linearisation order.
 #define DOTPROD_INNER_OP(typ) \
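The caller-side half of the PRECONDITIONS contract above lives in the new Haskell modules that follow: empty shapes are handled before calling C, and the vector offset is folded into the base pointer. A minimal sketch of that calling convention, under the same imports as the new Internal module (callKernel is a hypothetical helper, not part of this diff; the real wrappers are wrapUnary and friends):

    -- Sketch only: how the Haskell side establishes the C preconditions.
    -- 'kernel' stands for any of the oxarop_*_strided entry points.
    callKernel :: forall a. Storable a
               => (Int64 -> Ptr a -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ())
               -> [Int] -> [Int] -> Int -> VS.Vector a -> IO (VS.Vector a)
    callKernel kernel sh strides offset vec
      | any (<= 0) sh = pure VS.empty   -- never hand an empty array to C
      | otherwise = do
          outv <- VSM.unsafeNew (product sh)
          VSM.unsafeWith outv $ \pout ->
            VS.unsafeWith (VS.fromListN (length sh) (map fromIntegral sh)) $ \psh ->
            VS.unsafeWith (VS.fromListN (length sh) (map fromIntegral strides)) $ \pstr ->
            VS.unsafeWith vec $ \pv ->
              -- shift the base pointer so the C side may assume offset == 0
              kernel (fromIntegral (length sh)) pout psh pstr
                     (pv `plusPtr` (offset * sizeOf (undefined :: a)))
          VS.unsafeFreeze outv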
diff --git a/ops/Data/Array/Strided.hs b/ops/Data/Array/Strided.hs
new file mode 100644
index 0000000..a0506a9
--- /dev/null
+++ b/ops/Data/Array/Strided.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided (
+  module Data.Array.Strided.Array,
+  module Data.Array.Strided.Arith,
+) where
+
+import Data.Array.Strided.Array
+import Data.Array.Strided.Arith
diff --git a/ops/Data/Array/Strided/Arith.hs b/ops/Data/Array/Strided/Arith.hs
new file mode 100644
index 0000000..7be6390
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided.Arith (
+  NumElt(..),
+  IntElt(..),
+  FloatElt(..),
+) where
+
+import Data.Array.Strided.Arith.Internal
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
new file mode 100644
index 0000000..fe0fc4b
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -0,0 +1,866 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Strided.Arith.Internal where
+
+import Control.Monad
+import Data.Bifunctor (second)
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable
+import qualified GHC.TypeNats as TypeNats
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO (hFlush, stdout)
+import System.IO.Unsafe
+
+import Data.Array.Strided.Array
+import Data.Array.Strided.Arith.Internal.Lists
+import Data.Array.Strided.Arith.Internal.Foreign
+
+
+-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+
+
+-- TODO: move this to a utilities module
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
+data Dict c where
+  Dict :: c => Dict c
+
+debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
+debugShow (Array sh strides offset vec) =
+  "Array @" ++ (show (natVal (Proxy @n))) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">"
+
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise1 :: (Storable a, Storable b)
+               => SNat n
+               -> (Ptr a -> Ptr b)
+               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+               -> Array n a -> Array n a
+liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
+  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+      if blockSz == 0
+        then Array sh (map (const 0) strides) 0 VS.empty
+        else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec)
+             in Array sh strides (offset - blockOff) resvec
+  | otherwise = wrapUnary sn ptrconv cf_strided arr
+
+-- TODO: test all the cases of this thing with various input strides
+liftOpEltwise2 :: Storable a
+               => SNat n
+               -> (a -> b)
+               -> (Ptr a -> Ptr b)
+               -> (a -> a -> a)
+               -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
+               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
+               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
+               -> Array n a -> Array n a -> Array n a
+liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
+    arr1@(Array sh1 strides1 offset1 vec1)
+    arr2@(Array sh2 strides2 offset2 vec2)
+  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty
+  | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
+      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
+        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
+        in Array sh1 strides1 0 vec'
+
+      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense
+        let arr2' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec2)
+            resvec = arrValues $ wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+        in Array sh1 strides2 (offset2 - blockOff) resvec
+
+      (Just (_, 1), Nothing) ->  -- scalar * array
+        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
+      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar
+        let arr1' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec1)
+            resvec = arrValues $ wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+        in Array sh1 strides1 (offset1 - blockOff) resvec
+
+      (Nothing, Just (_, 1)) ->  -- array * scalar
+        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
+      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
+        | strides1 == strides2
+        ->  -- dense * dense but the strides match
+          if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2
+            then error $ "Data.Array.Strided.Ops.Internal(liftOpEltwise2): Internal error: cannot happen " ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2))
+            else
+              let arr1' = arrayFromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+                  arr2' = arrayFromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+                  resvec = arrValues $ wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
+              in Array sh1 strides1 (offset1 - blockOff1) resvec
+
+      (_, _) ->  -- fallback case
+        wrapBinaryVV sn ptrconv f_vv arr1 arr2
+
+-- | Given shape vector, offset and stride vector, check whether this virtual
+-- vector uses a dense subarray of its backing array. If so, the first index
+-- and the number of elements in this subarray is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
+stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
+stridesDense sh offsetNeg stridesNeg =
+  -- First reverse all dimensions with negative stride, so that the first used
+  -- value is at 'offset' and the rest is >= offset.
+  let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
+  in -- sort dimensions on their stride, ascending, dropping any zero strides
+     case filter ((/= 0) . fst) (sort (zip strides sh)) of
+       [] -> Just (offset, 1)
+       (1, n) : pairs -> (offset,) <$> checkCover n pairs
+       _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
+  where
+    -- Given size of currently densely covered region at beginning of the
+    -- array and the remaining (stride, size) pairs with all strides >=1,
+    -- return whether this all together covers a dense prefix of the array. If
+    -- it does, return the number of elements in this prefix.
+    checkCover :: Int -> [(Int, Int)] -> Maybe Int
+    checkCover block [] = Just block
+    checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs
+
+    -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
+    flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
+    flipReverseds [] off [] = (off, [])
+    flipReverseds (n : sh') off (s : str')
+      | s >= 0 = second (s :) (flipReverseds sh' off str')
+      | otherwise =
+          let off' = off + (n - 1) * s
+          in second ((-s) :) (flipReverseds sh' off' str')
+    flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
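A few hand-evaluated cases of stridesDense may help; these follow directly from the definition above and are not part of the source:

    -- stridesDense [2,3] 0 [3,1]  == Just (0, 6)  -- row-major view: dense block of 6
    -- stridesDense [2,3] 4 [3,1]  == Just (4, 6)  -- same block, starting at index 4
    -- stridesDense [2,3] 0 [1,2]  == Just (0, 6)  -- transposed, but still covers 0..5
    -- stridesDense [2,3] 0 [4,1]  == Nothing      -- gap at index 3: not dense
    -- stridesDense [2,3] 0 [0,1]  == Just (0, 3)  -- replicated outer dim drops out
    -- stridesDense [3]   2 [-1]   == Just (0, 3)  -- negative stride is flipped first
    -- stridesDense [2,0] 5 [3,1]  == Just (5, 0)  -- empty shape short-circuits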
+
+data Unreplicated a =
+  forall n'. KnownNat n' =>
+    -- | Let the original array, with replicated dimensions, be called A.
+    Unreplicated -- | An array with all strides /= 0. Call this array U. It has
+                 -- the same shape as A, except with all the replicated (stride
+                 -- == 0) dimensions removed. The shape of U is the
+                 -- "unreplicated shape".
+                 (Array n' a)
+                 -- | Product of sizes of the unreplicated dimensions
+                 Int
+                 -- | Given the stride vector of an array with the unreplicated
+                 -- shape, this function reinserts zeros so that it may be
+                 -- combined with the original shape of A.
+                 ([Int] -> [Int])
+
+-- | Removes all replicated dimensions (i.e. those with stride == 0) from the array.
+unreplicateStrides :: Array n a -> Unreplicated a
+unreplicateStrides (Array sh strides offset vec) =
+  let replDims = map (== 0) strides
+      (shF, stridesF) = unzip [(n, s) | (n, s) <- zip sh strides, s /= 0]
+
+      reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+      reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+      reinsertZeros [] [] = []
+      reinsertZeros (False : _) [] = error $ "unreplicateStrides: Internal error: reply strides too short"
+      reinsertZeros [] (_:_) = error $ "unreplicateStrides: Internal error: reply strides too long"
+
+      unrepSize = product [n | (n, True) <- zip sh replDims]
+
+  in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+       Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
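A worked example, derived by hand from the definition above (not from the source): unreplicating a shape-[2,4,3] array whose middle dimension is broadcast.

    -- a = Array [2,4,3] [3,0,1] 0 v           -- middle dimension has stride 0
    -- Unreplicated u size reinsert = unreplicateStrides a, where
    --   u              == Array [2,3] [3,1] 0 v   -- the unreplicated shape
    --   size           == 4                       -- product of dropped dimension sizes
    --   reinsert [3,1] == [3,0,1]                 -- zeros restored at the replicated spots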
+
+simplifyArray :: Array n a
+              -> (forall n'. KnownNat n'
+                          => Array n' a  -- U
+                          -- Product of sizes of the unreplicated dimensions
+                          -> Int
+                          -- Convert index in U back to index into original
+                          -- array. Replicated dimensions get 0.
+                          -> ([Int] -> [Int])
+                          -- Given a new array of the same shape as U, convert
+                          -- it back to the original shape and iteration order.
+                          -> (Array n' a -> Array n a)
+                          -- Do the same except without the INNER dimension.
+                          -- This throws an error if the inner dimension had
+                          -- stride 0.
+                          -> (Array (n' - 1) a -> Array (n - 1) a)
+                          -> r)
+              -> r
+simplifyArray array k
+  | let revDims = map (<0) (arrStrides array)
+  , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
+  = k array'
+      unrepSize
+      (\idx -> rereplicate (zipWith3 (\b n i -> if b then n - 1 - i else i)
+                                     revDims (arrShape array') idx))
+      (\(Array sh' strides' offset' vec') ->
+         if sh' == arrShape array'
+           then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
+           else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
+      (\(Array sh' strides' offset' vec') ->
+         if | sh' /= init (arrShape array') ->
+                error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
+            | last (arrStrides array) == 0 ->
+                error $ "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
+            | otherwise ->
+                arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
+
+{-# NOINLINE wrapUnary #-}
+wrapUnary :: forall a b n. Storable a
+          => SNat n
+          -> (Ptr a -> Ptr b)
+          -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+          -> Array n a
+          -> Array n a
+wrapUnary _ ptrconv cf_strided array =
+  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+    let ndims' = length sh
+    outv <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outv $ \poutv ->
+      VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+      VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+      VS.unsafeWith vec $ \pv ->
+        let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+        in cf_strided (fromIntegral ndims') (ptrconv poutv) psh pstrides pv'
+    restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: forall a b n. Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+             -> a -> Array n a
+             -> Array n a
+wrapBinarySV SNat valconv ptrconv cf_strided x array =
+  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+    let ndims' = length sh
+    outv <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outv $ \poutv ->
+      VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh ->
+      VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides ->
+      VS.unsafeWith vec $ \pv ->
+        let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a))
+        in cf_strided (fromIntegral ndims') psh (ptrconv poutv) (valconv x) pstrides pv'
+    restore . arrayFromVector sh <$> VS.unsafeFreeze outv
+
+wrapBinaryVS :: Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+             -> Array n a -> a
+             -> Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+  wrapBinarySV sn valconv ptrconv
+               (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
+
+-- | The two shapes must be equal and non-empty. This is checked.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: forall a b n. Storable a
+             => SNat n
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+             -> Array n a -> Array n a
+             -> Array n a
+-- TODO: do unreversing and unreplication on the input arrays (but
+-- simultaneously: can only unreplicate if _both_ are replicated on that
+-- dimension)
+wrapBinaryVV sn@SNat ptrconv cf_strided
+    (Array sh strides1 offset1 vec1)
+    (Array sh2 strides2 offset2 vec2)
+  | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
+  | any (<= 0) sh = error $ "wrapBinaryVV: empty shape: " ++ show sh
+  | otherwise = unsafePerformIO $ do
+      outv <- VSM.unsafeNew (product sh)
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
+        VS.unsafeWith vec1 $ \pv1 ->
+        VS.unsafeWith vec2 $ \pv2 ->
+          let pv1' = pv1 `plusPtr` (offset1 * sizeOf (undefined :: a))
+              pv2' = pv2 `plusPtr` (offset2 * sizeOf (undefined :: a))
+          in cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 pv1' pstrides2 pv2'
+      arrayFromVector sh <$> VS.unsafeFreeze outv
+
+-- TODO: test handling of negative strides
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+                 => SNat n
+                 -> (a -> b)
+                 -> (Ptr a -> Ptr b)
+                 -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                 -> Array (n + 1) a -> Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec)
+  | null sh = error "unreachable"
+  | last sh <= 0 = arrayFromConstant (init sh) 0
+  | any (<= 0) (init sh) = Array (init sh) (0 <$ init strides) 0 VS.empty
+  -- now the input array is nonempty
+  | last sh == 1 = Array (init sh) (init strides) offset vec
+  | last strides == 0 =
+      wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a (last sh))
+                   (Array (init sh) (init strides) offset vec)
+  -- now there is useful work along the inner dimension
+  -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point.
+  | otherwise =
+      simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore -> unsafePerformIO $ do
+        let ndims' = length sh'
+        outv <- VSM.unsafeNew (product (init sh'))
+        VSM.unsafeWith outv $ \poutv ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+          VS.unsafeWith vec' $ \pv ->
+            let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+            in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
+        TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+          (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+                                        LTI -> pure Dict
+                                        EQI -> pure Dict
+                                        _ -> error "impossible"  -- because `last strides /= 0`
+          case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+            Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv
+            Nothing -> error "impossible"
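For intuition, the `last strides == 0` case above turns a reduction over a broadcast inner dimension into a single scale-by-constant kernel call. A hand-worked instance (not from the source):

    -- arr = Array [2,5] [1,0] 0 (VS.fromList [10, 20])
    --   -- each row is the same element repeated 5 times (inner stride 0)
    -- summing the inner dimension is just multiplication by its length 5:
    --   result == Array [2] with elements [50, 100]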
+
+-- TODO: test handling of negative strides
+-- | Reduce full array
+{-# NOINLINE vectorRedFullOp #-}
+vectorRedFullOp :: forall a b n. (Num a, Storable a)
+                => SNat n
+                -> (a -> Int -> a)
+                -> (b -> a)
+                -> (Ptr a -> Ptr b)
+                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
+                -> Array n a -> a
+vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec)
+  | null sh = vec VS.! offset  -- 0D array has one element
+  | any (<= 0) sh = 0
+  -- now the input array is nonempty
+  | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset
+  -- now there is at least one non-replicated dimension
+  | otherwise =
+      simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do
+        let ndims' = length sh'
+        VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+          VS.unsafeWith vec' $ \pv ->
+            let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+            in (`scaleval` unrepSize) . valbackconv
+                 <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv')
+
+-- TODO: test this function
+-- | Find extremum (minindex ("argmin") or maxindex) in full array
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+                 => (Ptr a -> Ptr b)
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                 -> Array n a -> [Int]  -- result length: n
+vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
+  | null sh = []
+  | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
+  -- now the input array is nonempty
+  | all (== 0) strides = 0 <$ sh
+  -- now there is at least one non-replicated dimension
+  | otherwise =
+      simplifyArray array $ \(Array sh' strides' offset' vec') _ upindex _ _ -> unsafePerformIO $ do
+        let ndims' = length sh'
+        outvR <- VSM.unsafeNew (length sh')
+        VSM.unsafeWith outvR $ \poutv ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+          VS.unsafeWith vec' $ \pv ->
+            let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+            in fextrem poutv (fromIntegral ndims') psh pstrides (ptrconv pv')
+        upindex . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outvR
+
+{-# NOINLINE vectorDotprodInnerOp #-}
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+                     => SNat n
+                     -> (a -> b)
+                     -> (Ptr a -> Ptr b)
+                     -> (SNat n -> Array n a -> Array n a -> Array n a)  -- ^ elementwise multiplication
+                     -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
+                     -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+    arr1@(Array sh1 strides1 offset1 vec1)
+    arr2@(Array sh2 strides2 offset2 vec2)
+  | null sh1 || null sh2 = error "unreachable"
+  | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | last sh1 <= 0 = arrayFromConstant (init sh1) 0
+  | any (<= 0) (init sh1) = Array (init sh1) (0 <$ init strides1) 0 VS.empty
+  -- now the input arrays are nonempty
+  | last sh1 == 1 =
+      fmul sn (Array (init sh1) (init strides1) offset1 vec1)
+              (Array (init sh2) (init strides2) offset2 vec2)
+  | last strides1 == 0 =
+      fmul sn
+        (Array (init sh1) (init strides1) offset1 vec1)
+        (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+  | last strides2 == 0 =
+      fmul sn
+        (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+        (Array (init sh2) (init strides2) offset2 vec2)
+  -- now there is useful dotprod work along the inner dimension
+  | otherwise = unsafePerformIO $ do
+      let inrank = fromSNat' sn + 1
+      outv <- VSM.unsafeNew (product (init sh1))
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
+        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
+        VS.unsafeWith vec1 $ \pvec1 ->
+        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
+        VS.unsafeWith vec2 $ \pvec2 ->
+          fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+                    pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
+                    pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
+      arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+
+mulWithInt :: Num a => a -> Int -> a
+mulWithInt a i = a * fromIntegral i
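The Template Haskell splices below generate one monomorphic wrapper per (element type, operation) pair, each wired to the matching C kernel. Roughly, the first splice expands to bindings of this shape (an illustrative expansion written out by hand; in the generated code the aboEnum value appears as an integer literal):

    addVectorInt32 :: forall n. SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
    addVectorInt32 = \sn -> liftOpEltwise2 sn id id (+)
                       (c_binary_i32_sv_strided (aboEnum BO_ADD))
                       (c_binary_i32_vs_strided (aboEnum BO_ADD))
                       (c_binary_i32_vv_strided (aboEnum BO_ADD))

and likewise for sub, mul, and the other operations, across Int32, Int64, Float and Double.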
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          cnamebase = "c_binary_" ++ atCName arithtype
+          c_ss_str = varE (aboNumOp arithop)
+          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM intTypesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          cnamebase = "c_ibinary_" ++ atCName arithtype
+          c_ss_str = varE (aiboNumOp arithop)
+          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          cnamebase = "c_fbinary_" ++ atCName arithtype
+          c_ss_str = varE (afboNumOp arithop)
+          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+          c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+      sequence [SigD name <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+    fmap concat . forM [minBound..maxBound] $ \arithop -> do
+      let scaleVar = case arithop of
+                       RO_SUM -> varE 'mulWithInt
+                       RO_PRODUCT -> varE '(^)
+      let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
+          namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
+          c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+          c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+      sequence [SigD name1 <$>
+                     [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array n $ttyp |]
+               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
+                   return $ FunD name1 [Clause [] (NormalB body) []]
+               ,SigD namefull <$>
+                     [t| forall n. SNat n -> Array n $ttyp -> $ttyp |]
+               ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
+                   return $ FunD namefull [Clause [] (NormalB body) []]
+               ])
+
+$(fmap concat . forM typesList $ \arithtype ->
+    fmap concat . forM ["min", "max"] $ \fname -> do
+      let ttyp = conT (atType arithtype)
+          name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
+          c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+      sequence [SigD name <$>
+                     [t| forall n. Array n $ttyp -> [Int] |]
+               ,do body <- [| vectorExtremumOp id $c_op |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+    let ttyp = conT (atType arithtype)
+        name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+        c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+        mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+        c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+        c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
+    sequence [SigD name <$>
+                   [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array (n + 1) $ttyp -> Array n $ttyp |]
+             ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
+                 return $ FunD name [Clause [] (NormalB body) []]])
+
+foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
+foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
+
+statisticsEnable :: Bool -> IO ()
+statisticsEnable b = c_stats_enable (if b then 1 else 0)
+
+-- | Consumes the log: one particular event will only ever be printed once,
+-- even if statisticsPrintAll is called multiple times.
+statisticsPrintAll :: IO ()
+statisticsPrintAll = do
+  hFlush stdout  -- lower the chance of overlapping output
+  c_stats_print_all
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+              => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+              -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+              -> (SNat n -> Array n i -> Array n i)
+intWidBranch1 f32 f64 sn
+  | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr f32
+  | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr f64
+  | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+              => (i -> i -> i)  -- ss
+                 -- int32
+              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- sv
+              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- vs
+              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- vv
+                 -- int64
+              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- sv
+              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- vs
+              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- vv
+              -> (SNat n -> Array n i -> Array n i -> Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+  | finiteBitSize (undefined :: i) == 32 = liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+  | finiteBitSize (undefined :: i) == 64 = liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
+  | otherwise = error "Unsupported Int width"
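On a 64-bit platform, finiteBitSize (undefined :: Int) is 64, so these guards reduce to the i64 branch at compile time. For illustration, numEltAdd at Int (defined further below) then effectively becomes this hand-reduced form (numEltAddInt is a hypothetical name):

    numEltAddInt :: SNat n -> Array n Int -> Array n Int -> Array n Int
    numEltAddInt sn = liftOpEltwise2 sn fromIntegral castPtr (+)
                        (c_binary_i64_sv_strided (aboEnum BO_ADD))
                        (c_binary_i64_vs_strided (aboEnum BO_ADD))
                        (c_binary_i64_vv_strided (aboEnum BO_ADD))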
+
+intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
+                 => -- int32
+                    (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                 -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                    -- int64
+                 -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                 -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                 -> (SNat n -> Array (n + 1) i -> Array n i)
+intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
+  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+  | otherwise = error "Unsupported Int width"
+
+intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
+                    => (i -> Int -> i)  -- ^ scale op
+                       -- int32
+                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
+                       -- int64
+                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
+                    -> (SNat n -> Array n i -> i)
+intWidBranchRedFull fsc fred32 fred64 sn
+  | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
+  | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
+  | otherwise = error "Unsupported Int width"
+
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+                 => -- int32
+                    (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                    -- int64
+                 -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                 -> (Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+  | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+  | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+  | otherwise = error "Unsupported Int width"
+
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
+                    => -- int32
+                       (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                    -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
+                       -- int64
+                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
+                    -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
+                    -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+  | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+  | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
+  | otherwise = error "Unsupported Int width"
+
+class NumElt a where
+  numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a
+  numEltSub :: SNat n -> Array n a -> Array n a -> Array n a
+  numEltMul :: SNat n -> Array n a -> Array n a -> Array n a
+  numEltNeg :: SNat n -> Array n a -> Array n a
+  numEltAbs :: SNat n -> Array n a -> Array n a
+  numEltSignum :: SNat n -> Array n a -> Array n a
+  numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a
+  numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a
+  numEltSumFull :: SNat n -> Array n a -> a
+  numEltProductFull :: SNat n -> Array n a -> a
+  numEltMinIndex :: SNat n -> Array n a -> [Int]
+  numEltMaxIndex :: SNat n -> Array n a -> [Int]
+  numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+
+instance NumElt Int32 where
+  numEltAdd = addVectorInt32
+  numEltSub = subVectorInt32
+  numEltMul = mulVectorInt32
+  numEltNeg = negVectorInt32
+  numEltAbs = absVectorInt32
+  numEltSignum = signumVectorInt32
+  numEltSum1Inner = sum1VectorInt32
+  numEltProduct1Inner = product1VectorInt32
+  numEltSumFull = sumFullVectorInt32
+  numEltProductFull = productFullVectorInt32
+  numEltMinIndex _ = minindexVectorInt32
+  numEltMaxIndex _ = maxindexVectorInt32
+  numEltDotprodInner = dotprodinnerVectorInt32
+
+instance NumElt Int64 where
+  numEltAdd = addVectorInt64
+  numEltSub = subVectorInt64
+  numEltMul = mulVectorInt64
+  numEltNeg = negVectorInt64
+  numEltAbs = absVectorInt64
+  numEltSignum = signumVectorInt64
+  numEltSum1Inner = sum1VectorInt64
+  numEltProduct1Inner = product1VectorInt64
+  numEltSumFull = sumFullVectorInt64
+  numEltProductFull = productFullVectorInt64
+  numEltMinIndex _ = minindexVectorInt64
+  numEltMaxIndex _ = maxindexVectorInt64
+  numEltDotprodInner = dotprodinnerVectorInt64
+
+instance NumElt Float where
+  numEltAdd = addVectorFloat
+  numEltSub = subVectorFloat
+  numEltMul = mulVectorFloat
+  numEltNeg = negVectorFloat
+  numEltAbs = absVectorFloat
+  numEltSignum = signumVectorFloat
+  numEltSum1Inner = sum1VectorFloat
+  numEltProduct1Inner = product1VectorFloat
+  numEltSumFull = sumFullVectorFloat
+  numEltProductFull = productFullVectorFloat
+  numEltMinIndex _ = minindexVectorFloat
+  numEltMaxIndex _ = maxindexVectorFloat
+  numEltDotprodInner = dotprodinnerVectorFloat
+
+instance NumElt Double where
+  numEltAdd = addVectorDouble
+  numEltSub = subVectorDouble
+  numEltMul = mulVectorDouble
+  numEltNeg = negVectorDouble
+  numEltAbs = absVectorDouble
+  numEltSignum = signumVectorDouble
+  numEltSum1Inner = sum1VectorDouble
+  numEltProduct1Inner = product1VectorDouble
+  numEltSumFull = sumFullVectorDouble
+  numEltProductFull = productFullVectorDouble
+  numEltMinIndex _ = minindexVectorDouble
+  numEltMaxIndex _ = maxindexVectorDouble
+  numEltDotprodInner = dotprodinnerVectorDouble
+
+instance NumElt Int where
+  numEltAdd = intWidBranch2 @Int (+)
+                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+  numEltSub = intWidBranch2 @Int (-)
+                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+  numEltMul = intWidBranch2 @Int (*)
+                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+  numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+  numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+  numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+  numEltSum1Inner = intWidBranchRed1 @Int
+                      (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+                      (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+  numEltProduct1Inner = intWidBranchRed1 @Int
+                          (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+                          (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+  numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+  numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+  numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+  numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+  numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+                                                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+instance NumElt CInt where
+  numEltAdd = intWidBranch2 @CInt (+)
+                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+  numEltSub = intWidBranch2 @CInt (-)
+                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+  numEltMul = intWidBranch2 @CInt (*)
+                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+  numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+  numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+  numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+  numEltSum1Inner = intWidBranchRed1 @CInt
+                      (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+                      (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+  numEltProduct1Inner = intWidBranchRed1 @CInt
+                          (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+                          (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+  numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+  numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+  numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+  numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+  numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+                                                 (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+class NumElt a => IntElt a where
+  intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
+  intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance IntElt Int32 where
+  intEltQuot = quotVectorInt32
+  intEltRem = remVectorInt32
+
+instance IntElt Int64 where
+  intEltQuot = quotVectorInt64
+  intEltRem = remVectorInt64
+
+instance IntElt Int where
+  intEltQuot = intWidBranch2 @Int quot
+                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @Int rem
+                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+instance IntElt CInt where
+  intEltQuot = intWidBranch2 @CInt quot
+                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @CInt rem
+                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
+                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+
+class NumElt a => FloatElt a where
+  floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
+  floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a
+  floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a
+  floatEltRecip :: SNat n -> Array n a -> Array n a
+  floatEltExp :: SNat n -> Array n a -> Array n a
+  floatEltLog :: SNat n -> Array n a -> Array n a
+  floatEltSqrt :: SNat n -> Array n a -> Array n a
+  floatEltSin :: SNat n -> Array n a -> Array n a
+  floatEltCos :: SNat n -> Array n a -> Array n a
+  floatEltTan :: SNat n -> Array n a -> Array n a
+  floatEltAsin :: SNat n -> Array n a -> Array n a
+  floatEltAcos :: SNat n -> Array n a -> Array n a
+  floatEltAtan :: SNat n -> Array n a -> Array n a
+  floatEltSinh :: SNat n -> Array n a -> Array n a
+  floatEltCosh :: SNat n -> Array n a -> Array n a
+  floatEltTanh :: SNat n -> Array n a -> Array n a
+  floatEltAsinh :: SNat n -> Array n a -> Array n a
+  floatEltAcosh :: SNat n -> Array n a -> Array n a
+  floatEltAtanh :: SNat n -> Array n a -> Array n a
+  floatEltLog1p :: SNat n -> Array n a -> Array n a
+  floatEltExpm1 :: SNat n -> Array n a -> Array n a
+  floatEltLog1pexp :: SNat n -> Array n a -> Array n a
+  floatEltLog1mexp :: SNat n -> Array n a -> Array n a
+  floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a
+
+instance FloatElt Float where
+  floatEltDiv = divVectorFloat
+  floatEltPow = powVectorFloat
+  floatEltLogbase = logbaseVectorFloat
+  floatEltRecip = recipVectorFloat
+  floatEltExp = expVectorFloat
+  floatEltLog = logVectorFloat
+  floatEltSqrt = sqrtVectorFloat
+  floatEltSin = sinVectorFloat
+  floatEltCos = cosVectorFloat
+  floatEltTan = tanVectorFloat
+  floatEltAsin = asinVectorFloat
+  floatEltAcos = acosVectorFloat
+  floatEltAtan = atanVectorFloat
+  floatEltSinh = sinhVectorFloat
+  floatEltCosh = coshVectorFloat
+  floatEltTanh = tanhVectorFloat
+  floatEltAsinh = asinhVectorFloat
+  floatEltAcosh = acoshVectorFloat
+  floatEltAtanh = atanhVectorFloat
+  floatEltLog1p = log1pVectorFloat
+  floatEltExpm1 = expm1VectorFloat
+  floatEltLog1pexp = log1pexpVectorFloat
+  floatEltLog1mexp = log1mexpVectorFloat
+  floatEltAtan2 = atan2VectorFloat
+
+instance FloatElt Double where
+  floatEltDiv = divVectorDouble
+  floatEltPow = powVectorDouble
+  floatEltLogbase = logbaseVectorDouble
+  floatEltRecip = recipVectorDouble
+  floatEltExp = expVectorDouble
+  floatEltLog = logVectorDouble
+  floatEltSqrt = sqrtVectorDouble
+  floatEltSin = sinVectorDouble
+  floatEltCos = cosVectorDouble
+  floatEltTan = tanVectorDouble
+  floatEltAsin = asinVectorDouble
+  floatEltAcos = acosVectorDouble
+  floatEltAtan = atanVectorDouble
+  floatEltSinh = sinhVectorDouble
+  floatEltCosh = coshVectorDouble
+  floatEltTanh = tanhVectorDouble
+  floatEltAsinh = asinhVectorDouble
+  floatEltAcosh = acoshVectorDouble
+  floatEltAtanh = atanhVectorDouble
+  floatEltLog1p = log1pVectorDouble
+  floatEltExpm1 = expm1VectorDouble
+  floatEltLog1pexp = log1pexpVectorDouble
+  floatEltLog1mexp = log1mexpVectorDouble
+  floatEltAtan2 = atan2VectorDouble
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
index 78d5365..dad65f9 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
@@ -1,13 +1,13 @@
 {-# LANGUAGE ForeignFunctionInterface #-}
 {-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Mixed.Internal.Arith.Foreign where
+module Data.Array.Strided.Arith.Internal.Foreign where
 
 import Data.Int
 import Foreign.C.Types
 import Foreign.Ptr
 import Language.Haskell.TH
 
-import Data.Array.Mixed.Internal.Arith.Lists
+import Data.Array.Strided.Arith.Internal.Lists
 
 
 $(do
b/ops/Data/Array/Strided/Arith/Internal/Lists.hs index 370b708..910a77c 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -1,12 +1,12 @@  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE TemplateHaskell #-} -module Data.Array.Mixed.Internal.Arith.Lists where +module Data.Array.Strided.Arith.Internal.Lists where  import Data.Char  import Data.Int  import Language.Haskell.TH -import Data.Array.Mixed.Internal.Arith.Lists.TH +import Data.Array.Strided.Arith.Internal.Lists.TH  data ArithType = ArithType diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs index a156e29..b8f6a3d 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs +++ b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs @@ -1,5 +1,5 @@  {-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Mixed.Internal.Arith.Lists.TH where +module Data.Array.Strided.Arith.Internal.Lists.TH where  import Control.Monad  import Control.Monad.IO.Class diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs new file mode 100644 index 0000000..a772aaf --- /dev/null +++ b/ops/Data/Array/Strided/Array.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Data.Array.Strided.Array where + +import qualified Data.List.NonEmpty as NE +import Data.Proxy +import qualified Data.Vector.Storable as VS +import Foreign.Storable +import GHC.TypeLits + + +data Array (n :: Nat) a = Array +  { arrShape :: ![Int] +  , arrStrides :: ![Int] +  , arrOffset :: !Int +  , arrValues :: !(VS.Vector a) +  } + +-- | Takes a vector in normalised order (inner dimension, i.e. last in the +-- list, iterates fastest). +arrayFromVector :: forall a n. 
(Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a +arrayFromVector sh vec +  | VS.length vec == shsize +  , length sh == fromIntegral (natVal (Proxy @n)) +  = Array sh strides 0 vec +  | otherwise = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec) +  where +    shsize = product sh +    strides = NE.tail (NE.scanr (*) 1 sh) + +arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a +arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x) + +arrayRevDims :: [Bool] -> Array n a -> Array n a +arrayRevDims bs (Array sh strides offset vec) +  | length bs == length sh = +      Array sh +            (zipWith (\b s -> if b then -s else s) bs strides) +            (offset + sum (zipWith3 (\b n s -> if b then (n - 1) * s else 0) bs sh strides)) +            vec +  | otherwise = error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank " ++ show (length sh) diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 19a61ab..ecd3ba7 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -30,9 +30,6 @@ flag nonportable-simd  library    exposed-modules:      Data.Array.Mixed.Internal.Arith -    Data.Array.Mixed.Internal.Arith.Foreign -    Data.Array.Mixed.Internal.Arith.Lists -    Data.Array.Mixed.Internal.Arith.Lists.TH      Data.Array.Mixed.Lemmas      Data.Array.Mixed.Permutation      Data.Array.Mixed.Shape @@ -52,6 +49,8 @@ library        Data.Array.Nested.Trace.TH    build-depends: +    strided-array-ops, +      base >=4.18 && <4.21,      deepseq,      ghc-typelits-knownnat, @@ -60,6 +59,27 @@ library      template-haskell,      vector    hs-source-dirs: src + +  default-language: Haskell2010 +  ghc-options: -Wall +  other-extensions: TemplateHaskell + +library strided-array-ops +  exposed-modules: +    Data.Array.Strided +    Data.Array.Strided.Array +    Data.Array.Strided.Arith +    Data.Array.Strided.Arith.Internal +    Data.Array.Strided.Arith.Internal.Foreign +    Data.Array.Strided.Arith.Internal.Lists +    Data.Array.Strided.Arith.Internal.Lists.TH +  build-depends: +    base, +    ghc-typelits-knownnat, +    ghc-typelits-natnormalise, +    template-haskell, +    vector +  hs-source-dirs: ops    c-sources: cbits/arith.c    cc-options: -O3 -Wall -Wextra -std=c99 @@ -112,6 +132,7 @@ benchmark bench    main-is: Main.hs    build-depends:      ox-arrays, +    strided-array-ops,      base,      hmatrix,      orthotope, diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 27ebb64..f7a76bc 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -1,929 +1,23 @@ -{-# LANGUAGE DataKinds #-}  {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed.Internal.Arith where -import Control.Monad (forM, guard)  import Data.Array.Internal qualified as OI  import Data.Array.Internal.RankedG qualified as RG  import Data.Array.Internal.RankedS qualified as RS -import Data.Bifunctor (second) -import Data.Bits -import Data.Int -import Data.List (sort) -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types -import Foreign.Marshal.Alloc (alloca) -import Foreign.Ptr -import 
Foreign.Storable (Storable(sizeOf), peek, poke) -import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats -import Language.Haskell.TH -import System.IO (hFlush, stdout) -import System.IO.Unsafe -import Data.Array.Mixed.Internal.Arith.Foreign -import Data.Array.Mixed.Internal.Arith.Lists -import Data.Array.Mixed.Types (fromSNat') +import Data.Array.Strided qualified as AS --- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) --- TODO: test all the cases of this thing with various input strides -liftVEltwise1 :: (Storable a, Storable b) -              => SNat n -              -> (VS.Vector a -> VS.Vector b) -              -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) -  | Just (blockOff, blockSz) <- stridesDense sh offset strides = -      let vec' = f (VS.slice blockOff blockSz vec) -      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) -  | otherwise = RS.fromVector sh (f (RS.toVector arr)) +liftO1 :: (AS.Array n a -> AS.Array n' b) +       -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO --- TODO: test all the cases of this thing with various input strides -{-# NOINLINE liftOpEltwise1 #-} -liftOpEltwise1 :: (Storable a, Storable b) -               => SNat n -               -> (Ptr a -> Ptr a') -               -> (Ptr b -> Ptr b') -               -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ()) -               -> RS.Array n a -> RS.Array n b -liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) -  -- TODO: less code duplication between these two branches -  | Just (blockOff, blockSz) <- stridesDense sh offset strides = -      if blockSz == 0 -        then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty)) -        else unsafePerformIO $ do -               outv <- VSM.unsafeNew blockSz -               VSM.unsafeWith outv $ \poutv -> -                 VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> -                   VS.unsafeWith (VS.singleton 1) $ \pstrides -> -                     VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> -                       cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) -               RS.A . RG.A sh . 
OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv -  | otherwise = unsafePerformIO $ do -      outv <- VSM.unsafeNew (product sh) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> -          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> -            VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> -              cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv) -      RS.fromVector sh <$> VS.unsafeFreeze outv - --- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: Storable a -              => SNat n -              -> (a -> b) -              -> (Ptr a -> Ptr b) -              -> (a -> a -> a) -              -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv -              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs -              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv -              -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv -    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) -    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) -  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 -  | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty)) -  | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of -      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars -        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2)) -        in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - -      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense -        let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2) -            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2' -        in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec)) - -      (Just (_, 1), Nothing) ->  -- scalar * array -        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2 - -      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar -        let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1) -            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2) -        in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec)) - -      (Nothing, Just (_, 1)) ->  -- array * scalar -        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! 
offset2) - -      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) -        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the strides check -        , strides1 == strides2 -        ->  -- dense * dense but the strides match -          let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1) -              arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2) -              RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2' -          in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec)) - -      (_, _) ->  -- fallback case -        wrapBinaryVV sn ptrconv f_vv arr1 arr2 - --- | Given shape vector, offset and stride vector, check whether this virtual --- vector uses a dense subarray of its backing array. If so, the first index --- and the number of elements in this subarray is returned. --- This excludes any offset. -stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int) -stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0) -stridesDense sh offsetNeg stridesNeg = -  -- First reverse all dimensions with negative stride, so that the first used -  -- value is at 'offset' and the rest is >= offset. -  let (offset, strides) = flipReverseds sh offsetNeg stridesNeg -  in -- sort dimensions on their stride, ascending, dropping any zero strides -     case filter ((/= 0) . fst) (sort (zip strides sh)) of -       [] -> Just (offset, 1) -       (1, n) : pairs -> (offset,) <$> checkCover n pairs -       _ -> Nothing  -- if the smallest stride is not 1, it will never be dense -  where -    -- Given size of currently densely covered region at beginning of the -    -- array and the remaining (stride, size) pairs with all strides >=1, -    -- return whether this all together covers a dense prefix of the array. If -    -- it does, return the number of elements in this prefix. 
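The same check in isolation, as a runnable sketch (densePrefix is a hypothetical name; it assumes the preprocessing that stridesDense has already done at this point: zero strides dropped and negative strides flipped positive):

import Control.Monad (guard)
import Data.List (sort)

-- Input: (stride, extent) pairs. Output: the size of the contiguous block
-- the dimensions cover, if they tile one without gaps.
densePrefix :: [(Int, Int)] -> Maybe Int
densePrefix dims = case sort dims of
  []            -> Just 1
  (1, n) : rest -> go n rest
  _             -> Nothing  -- smallest stride /= 1: can never be dense
  where
    go block []              = Just block
    go block ((s, n) : rest) = guard (s <= block) >> go ((n - 1) * s + block) rest

-- densePrefix [(3,2), (1,3)] == Just 6   -- shape [2,3], strides [3,1]: contiguous
-- densePrefix [(4,2), (1,3)] == Nothing  -- stride 4 jumps past the 3-element block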
-    checkCover :: Int -> [(Int, Int)] -> Maybe Int -    checkCover block [] = Just block -    checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs - -    -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0 -    flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int]) -    flipReverseds [] off [] = (off, []) -    flipReverseds (n : sh') off (s : str') -      | s >= 0 = second (s :) (flipReverseds sh' off str') -      | otherwise = -          let off' = off + (n - 1) * s -          in second ((-s) :) (flipReverseds sh' off' str') -    flipReverseds _ _ _ = error "flipReverseds: invalid arguments" - -{-# NOINLINE wrapBinarySV #-} -wrapBinarySV :: Storable a -             => SNat n -             -> (a -> b) -             -> (Ptr a -> Ptr b) -             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -             -> a -> RS.Array n a -             -> RS.Array n a -wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) = -  unsafePerformIO $ do -    outv <- VSM.unsafeNew (product sh) -    VSM.unsafeWith outv $ \poutv -> -      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> -        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> -          VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> -            cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv) -    RS.fromVector sh <$> VS.unsafeFreeze outv - -wrapBinaryVS :: Storable a -             => SNat n -             -> (a -> b) -             -> (Ptr a -> Ptr b) -             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -             -> RS.Array n a -> a -             -> RS.Array n a -wrapBinaryVS sn valconv ptrconv cf_strided arr y = -  wrapBinarySV sn valconv ptrconv -               (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr - --- | This function assumes that the two shapes are equal. -{-# NOINLINE wrapBinaryVV #-} -wrapBinaryVV :: Storable a -             => SNat n -             -> (Ptr a -> Ptr b) -             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -             -> RS.Array n a -> RS.Array n a -             -> RS.Array n a -wrapBinaryVV sn@SNat ptrconv cf_strided -    (RS.A (RG.A sh (OI.T strides1 offset1 vec1))) -    (RS.A (RG.A _  (OI.T strides2 offset2 vec2))) = -  unsafePerformIO $ do -    outv <- VSM.unsafeNew (product sh) -    VSM.unsafeWith outv $ \poutv -> -      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> -      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 -> -      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 -> -      VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 -> -      VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 -> -        cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2) -    RS.fromVector sh <$> VS.unsafeFreeze outv - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. 
Storable a -          => (Ptr a -> Ptr b) -          -> (Int64 -> Ptr b -> Ptr b -> IO ()) -          -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do -  outv <- VSM.unsafeNew (VS.length v) -  VSM.unsafeWith outv $ \poutv -> -    VS.unsafeWith v $ \pv -> -      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) -  VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a -          => (a -> b) -          -> (Ptr a -> Ptr b) -          -> (a -> a -> a) -          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv -          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs -          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv -          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases -  (Left x) (Left y) -> VS.singleton (fss x y) - -  (Left x) (Right vy) -> -    unsafePerformIO $ do -      outv <- VSM.unsafeNew (VS.length vy) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith vy $ \pvy -> -          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) -      VS.unsafeFreeze outv - -  (Right vx) (Left y) -> -    unsafePerformIO $ do -      outv <- VSM.unsafeNew (VS.length vx) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith vx $ \pvx -> -          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) -      VS.unsafeFreeze outv - -  (Right vx) (Right vy) -    | VS.length vx == VS.length vy -> -        unsafePerformIO $ do -          outv <- VSM.unsafeNew (VS.length vx) -          VSM.unsafeWith outv $ \poutv -> -            VS.unsafeWith vx $ \pvx -> -              VS.unsafeWith vy $ \pvy -> -                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) -          VS.unsafeFreeze outv -    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test handling of negative strides --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) -                 => SNat n -                 -> (a -> b) -                 -> (Ptr a -> Ptr b) -                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel -                 -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) -  | null sh = error "unreachable" -  | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0]) -  | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty)) -  -- now the input array is nonempty -  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) -  | last strides == 0 = -      liftVEltwise1 sn -        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) -        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) -  -- now there is useful work along the inner dimension -  | otherwise = -      let -- replicated dimensions: dimensions with zero stride. The reduction -          -- kernel need not concern itself with those (and in fact has a -          -- precondition that there are no such dimensions in its input). 
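A zero stride marks a broadcast (replicated) dimension: every index along it reads the same element. The kernel is handed only the non-replicated dimensions and the result is stretched back to the full shape afterwards; in the `last strides == 0` branch above this degenerates further, since summing a replicated inner dimension is just scaling by its extent. A self-contained sketch of the filtering step (dropReplicated is a hypothetical name):

-- Returns (filtered shape, filtered strides, shape with replicated
-- dimensions collapsed to 1, for re-expansion via reshape/stretch).
dropReplicated :: [Int] -> [Int] -> ([Int], [Int], [Int])
dropReplicated sh strides =
  let repl            = map (== 0) strides
      (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides repl]
      shOnes          = zipWith (\n r -> if r then 1 else n) sh repl
  in (shF, stridesF, shOnes)

-- dropReplicated [4,5,6] [30,0,1] == ([4,6], [30,1], [4,1,6])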
-          replDims = map (== 0) strides -          -- filter out replicated dimensions -          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] -          -- replace replicated dimensions with ones -          shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims -          ndimsF = length shF  -- > 0, otherwise `last strides == 0` - -          -- reversed dimensions: dimensions with negative stride. Reversal is -          -- irrelevant for a reduction, and indeed the kernel has a -          -- precondition that there are no such dimensions. -          revDims = map (< 0) stridesF -          stridesR = map abs stridesF -          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) -          -- The *R values give an array with strides all > 0, hence the -          -- left-most element is at offsetR. -      in unsafePerformIO $ do -           outvR <- VSM.unsafeNew (product (init shF)) -           VSM.unsafeWith outvR $ \poutvR -> -             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> -                 VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> -                   fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) -           TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> -             RS.stretch (init sh)  -- replicate to original shape -               . RS.reshape (init shOnes)  -- add 1-sized dimensions where the original was replicated -               . RS.rev (map fst (filter snd (zip [0..] revDims)))  -- re-reverse the correct dimensions -               . RS.fromVector @_ @lenFm1 (init shF)  -- the partially-reversed result array -               <$> VS.unsafeFreeze outvR - --- TODO: test handling of negative strides --- | Reduce full array -{-# NOINLINE vectorRedFullOp #-} -vectorRedFullOp :: forall a b n. (Num a, Storable a) -                => SNat n -                -> (a -> Int -> a) -                -> (b -> a) -                -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel -                -> RS.Array n a -> a -vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec))) -  | null sh = vec VS.! offset  -- 0D array has one element -  | any (<= 0) sh = 0 -  -- now the input array is nonempty -  | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset -  -- now there is at least one non-replicated dimension -  | otherwise = -      let -- replicated dimensions: dimensions with zero stride. The reduction -          -- kernel need not concern itself with those (and in fact has a -          -- precondition that there are no such dimensions in its input). -          replDims = map (== 0) strides -          -- filter out replicated dimensions -          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] -          ndimsF = length shF  -- > 0, otherwise `all (== 0) strides` -          -- we should scale up the output this many times to account for the replicated dimensions -          multiplier = product [n | (n, True) <- zip sh replDims] - -          -- reversed dimensions: dimensions with negative stride. Reversal is -          -- irrelevant for a reduction, and indeed the kernel has a -          -- precondition that there are no such dimensions. 
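The reversal normalisation in isolation: a negative stride means the dimension is a reversed view, and since element order is irrelevant to a reduction, the stride can be made positive as long as the offset is moved to the first element of the flipped view. A minimal sketch (unReverse is a hypothetical name):

unReverse :: Int -> [Int] -> [Int] -> (Int, [Int])
unReverse offset sh strides =
  ( offset + sum (zipWith (\n s -> if s < 0 then (n - 1) * s else 0) sh strides)
  , map abs strides )

-- unReverse 9 [2,5] [5,-1] == (5, [5,1])
-- The reversed view's first element sat at linear index 9; the flipped
-- view's first element is at 9 + (5-1)*(-1) = 5.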
-          revDims = map (< 0) stridesF -          stridesR = map abs stridesF -          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) -          -- The *R values give an array with strides all > 0, hence the -          -- left-most element is at offsetR. -      in unsafePerformIO $ do -           VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> -               VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> -                 (`scaleval` multiplier) . valbackconv -                   <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) - --- TODO: test this function --- | Find extremum (minindex ("argmin") or maxindex) in full array -{-# NOINLINE vectorExtremumOp #-} -vectorExtremumOp :: forall a b n. Storable a -                 => (Ptr a -> Ptr b) -                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel -                 -> RS.Array n a -> [Int]  -- result length: n -vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) -  | null sh = [] -  | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array" -  -- now the input array is nonempty -  | all (== 0) strides = 0 <$ sh -  -- now there is at least one non-replicated dimension -  | otherwise = -      let -- replicated dimensions: dimensions with zero stride. The extremum -          -- kernel need not concern itself with those (and in fact has a -          -- precondition that there are no such dimensions in its input). -          replDims = map (== 0) strides -          -- filter out replicated dimensions -          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] -          ndimsF = length shF  -- > 0, because not all strides were <=0 - -          -- un-reverse reversed dimensions -          revDims = map (< 0) stridesF -          stridesR = map abs stridesF -          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) - -          -- function to insert zeros in replicated-out dimensions -          insertZeros :: [Bool] -> [Int] -> [Int] -          insertZeros [] idx = idx -          insertZeros (True : repls) idx = 0 : insertZeros repls idx -          insertZeros (False : repls) (i : idx) = i : insertZeros repls idx -          insertZeros (_:_) [] = error "unreachable" -      in unsafePerformIO $ do -           outvR <- VSM.unsafeNew (length shF) -           VSM.unsafeWith outvR $ \poutvR -> -             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> -                 VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> -                   fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) -           insertZeros replDims -             . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF  -- re-reverse the reversed dimensions -             . map (fromIntegral @Int64 @Int) -             . VS.toList -             <$> VS.unsafeFreeze outvR - -vectorDotprodInnerOp :: forall a b n. 
(Num a, Storable a) -                     => SNat n -                     -> (a -> b) -                     -> (Ptr a -> Ptr b) -                     -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a)  -- ^ elementwise multiplication -                     -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant -                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel -                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel -                     -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a -vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner -    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) -    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) -  | null sh1 || null sh2 = error "unreachable" -  | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 -  | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0]) -  | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty)) -  -- now the input arrays are nonempty -  | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2) -  | last strides1 == 0 = -      fmul sn -        (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1))) -        (vectorRedInnerOp sn valconv ptrconv fscale fred arr2) -  | last strides2 == 0 = -      fmul sn -        (vectorRedInnerOp sn valconv ptrconv fscale fred arr1) -        (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2))) -  -- now there is useful dotprod work along the inner dimension -  | otherwise = unsafePerformIO $ do -      let inrank = fromSNat' sn + 1 -      outv <- VSM.unsafeNew (product (init sh1)) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 -> -        VS.unsafeWith vec1 $ \pvec1 -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> -        VS.unsafeWith vec2 $ \pvec2 -> -          fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) -                    pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) -                    pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) -      RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv - -{-# NOINLINE dotScalarVector #-} -dotScalarVector :: forall a b. 
(Num a, Storable a) -                => Int -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel -                -> a -> VS.Vector a -> a -dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do -  alloca @a $ \pout -> do -    alloca @Int64 $ \pshape -> do -      poke pshape (fromIntegral @Int @Int64 len) -      alloca @Int64 $ \pstride -> do -        poke pstride 1 -        VS.unsafeWith vec $ \pvec -> -          fred 1 (ptrconv pout) pshape pstride (ptrconv pvec) -    res <- peek pout -    return (scalar * res) - -{-# NOINLINE dotVectorVector #-} -dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel -                -> VS.Vector a -> VS.Vector a -> a -dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do -  VS.unsafeWith vec1 $ \pvec1 -> -    VS.unsafeWith vec2 $ \pvec2 -> -      valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2) - -{-# NOINLINE dotVectorVectorStrided #-} -dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) -                       -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ dotprod kernel -                       -> Int -> Int -> VS.Vector a -                       -> Int -> Int -> VS.Vector a -                       -> a -dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do -  VS.unsafeWith vec1 $ \pvec1 -> -    VS.unsafeWith vec2 $ \pvec2 -> -      valbackconv <$> fdot (fromIntegral @Int @Int64 len) -                           (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1) -                           (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2) - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -       ->  Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_binary_" ++ atCName arithtype -          c_ss_str = varE (aboNumOp arithop) -          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM intTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . 
forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_ibinary_" ++ atCName arithtype -          c_ss_str = varE (aiboNumOp arithop) -          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) -          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) -          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_fbinary_" ++ atCName arithtype -          c_ss_str = varE (afboNumOp arithop) -          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. 
SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -mulWithInt :: Num a => a -> Int -> a -mulWithInt a i = a * fromIntegral i - -scaleFromSVStrided :: (Int64 -> Ptr Int64 -> Ptr a -> a -> Ptr Int64 -> Ptr a -> IO ()) -                   -> Int64 -> Ptr a -> a -> Ptr a -> IO () -scaleFromSVStrided fsv n out x ys = -  VS.unsafeWith (VS.singleton n) $ \psh -> -    VS.unsafeWith (VS.singleton 1) $ \pstrides -> -      fsv 1 psh out x pstrides ys - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let scaleVar = case arithop of -                       RO_SUM -> varE 'mulWithInt -                       RO_PRODUCT -> varE '(^) -      let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype)) -          namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype)) -          c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) -          c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) -          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) -      sequence [SigD name1 <$> -                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> vectorRedInnerOp sn id id (scaleFromSVStrided $c_scale_op) $c_op1 |] -                   return $ FunD name1 [Clause [] (NormalB body) []] -               ,SigD namefull <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] -               ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |] -                   return $ FunD namefull [Clause [] (NormalB body) []] -               ]) - -$(fmap concat . forM typesList $ \arithtype -> -    fmap concat . forM ["min", "max"] $ \fname -> do -      let ttyp = conT (atType arithtype) -          name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype)) -      sequence [SigD name <$> -                     [t| forall n. RS.Array n $ttyp -> [Int] |] -               ,do body <- [| vectorExtremumOp id $c_op |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -        name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype)) -        c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype)) -        mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype))) -        c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) -        c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM))) -    sequence [SigD name <$> -                   [t| forall n. 
SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -             ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |] -                 return $ FunD name [Clause [] (NormalB body) []]]) - -foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO () -foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO () - -statisticsEnable :: Bool -> IO () -statisticsEnable b = c_stats_enable (if b then 1 else 0) - --- | Consumes the log: one particular event will only ever be printed once, --- even if statisticsPrintAll is called multiple times. -statisticsPrintAll :: IO () -statisticsPrintAll = do -  hFlush stdout  -- lower the chance of overlapping output -  c_stats_print_all - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) -              => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -              -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn -  | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32 -  | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64 -  | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) -              => (i -> i -> i)  -- ss -                 -- int32 -              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- sv -              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ())  -- vs -              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- vv -                 -- int64 -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- sv -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv -              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn -  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32 -  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64 -  | otherwise = error "Unsupported Int width" - -intWidBranchRed1 :: forall i n. 
(FiniteBits i, Storable i, Integral i) -                 => -- int32 -                    (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel -                    -- int64 -                 -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                 -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn -  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 -  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 -  | otherwise = error "Unsupported Int width" - -intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i) -                    => (i -> Int -> i)  -- ^ scale op -                       -- int32 -                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32)  -- ^ reduction kernel -                       -- int64 -                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ reduction kernel -                    -> (SNat n -> RS.Array n i -> i) -intWidBranchRedFull fsc fred32 fred64 sn -  | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32 -  | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 -  | otherwise = error "Unsupported Int width" - -intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i) -                 => -- int32 -                    (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ extremum kernel -                    -- int64 -                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ extremum kernel -                 -> (RS.Array n i -> [Int]) -intWidBranchExtr fextr32 fextr64 -  | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32 -  | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 -  | otherwise = error "Unsupported Int width" - -intWidBranchDotprod :: forall i n. 
(FiniteBits i, Storable i, Integral i, NumElt i) -                    => -- int32 -                       (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant -                    -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel -                    -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ dotprod kernel -                       -- int64 -                    -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant -                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ dotprod kernel -                    -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn -  | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32 -  | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64 -  | otherwise = error "Unsupported Int width" - -class NumElt a where -  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a -  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltSumFull :: SNat n -> RS.Array n a -> a -  numEltProductFull :: SNat n -> RS.Array n a -> a -  numEltMinIndex :: SNat n -> RS.Array n a -> [Int] -  numEltMaxIndex :: SNat n -> RS.Array n a -> [Int] -  numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where -  numEltAdd = addVectorInt32 -  numEltSub = subVectorInt32 -  numEltMul = mulVectorInt32 -  numEltNeg = negVectorInt32 -  numEltAbs = absVectorInt32 -  numEltSignum = signumVectorInt32 -  numEltSum1Inner = sum1VectorInt32 -  numEltProduct1Inner = product1VectorInt32 -  numEltSumFull = sumFullVectorInt32 -  numEltProductFull = productFullVectorInt32 -  numEltMinIndex _ = minindexVectorInt32 -  numEltMaxIndex _ = maxindexVectorInt32 -  numEltDotprodInner = dotprodinnerVectorInt32 - -instance NumElt Int64 where -  numEltAdd = addVectorInt64 -  numEltSub = subVectorInt64 -  numEltMul = mulVectorInt64 -  numEltNeg = negVectorInt64 -  numEltAbs = absVectorInt64 -  numEltSignum = signumVectorInt64 -  numEltSum1Inner = sum1VectorInt64 -  numEltProduct1Inner = product1VectorInt64 -  numEltSumFull = sumFullVectorInt64 -  numEltProductFull = productFullVectorInt64 -  numEltMinIndex _ = minindexVectorInt64 -  numEltMaxIndex _ = maxindexVectorInt64 -  numEltDotprodInner = dotprodinnerVectorInt64 - -instance NumElt Float where -  numEltAdd = addVectorFloat -  numEltSub = subVectorFloat -  numEltMul = mulVectorFloat -  numEltNeg = negVectorFloat -  numEltAbs = absVectorFloat -  numEltSignum = signumVectorFloat -  numEltSum1Inner = sum1VectorFloat -  numEltProduct1Inner = product1VectorFloat -  numEltSumFull = sumFullVectorFloat -  numEltProductFull = productFullVectorFloat -  
numEltMinIndex _ = minindexVectorFloat -  numEltMaxIndex _ = maxindexVectorFloat -  numEltDotprodInner = dotprodinnerVectorFloat - -instance NumElt Double where -  numEltAdd = addVectorDouble -  numEltSub = subVectorDouble -  numEltMul = mulVectorDouble -  numEltNeg = negVectorDouble -  numEltAbs = absVectorDouble -  numEltSignum = signumVectorDouble -  numEltSum1Inner = sum1VectorDouble -  numEltProduct1Inner = product1VectorDouble -  numEltSumFull = sumFullVectorDouble -  numEltProductFull = productFullVectorDouble -  numEltMinIndex _ = minindexVectorDouble -  numEltMaxIndex _ = maxindexVectorDouble -  numEltDotprodInner = dotprodinnerVectorDouble - -instance NumElt Int where -  numEltAdd = intWidBranch2 @Int (+) -                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) -                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) -  numEltSub = intWidBranch2 @Int (-) -                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) -                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) -  numEltMul = intWidBranch2 @Int (*) -                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) -                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed1 @Int -                      (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) -                      (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) -  numEltProduct1Inner = intWidBranchRed1 @Int -                          (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) -                          (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT)) -  numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) -  numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) -  numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 -  numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprodInner = intWidBranchDotprod @Int (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 -                                                (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 - -instance NumElt CInt where -  numEltAdd = intWidBranch2 @CInt (+) -                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided 
(aboEnum BO_ADD)) -                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) -  numEltSub = intWidBranch2 @CInt (-) -                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) -                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) -  numEltMul = intWidBranch2 @CInt (*) -                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) -                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed1 @CInt -                      (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) -                      (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) -  numEltProduct1Inner = intWidBranchRed1 @CInt -                          (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) -                          (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT)) -  numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) -  numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) -  numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 -  numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 -                                                 (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 - -class NumElt a => IntElt a where -  intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - -instance IntElt Int32 where -  intEltQuot = quotVectorInt32 -  intEltRem = remVectorInt32 - -instance IntElt Int64 where -  intEltQuot = quotVectorInt64 -  intEltRem = remVectorInt64 - -instance IntElt Int where -  intEltQuot = intWidBranch2 @Int quot -                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) -                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) -  intEltRem = intWidBranch2 @Int rem -                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) -                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) 
(c_binary_i64_vv_strided (aiboEnum IB_REM)) - -instance IntElt CInt where -  intEltQuot = intWidBranch2 @CInt quot -                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) -                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) -  intEltRem = intWidBranch2 @CInt rem -                (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) -                (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) - -class NumElt a => FloatElt a where -  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where -  floatEltDiv = divVectorFloat -  floatEltPow = powVectorFloat -  floatEltLogbase = logbaseVectorFloat -  floatEltRecip = recipVectorFloat -  floatEltExp = expVectorFloat -  floatEltLog = logVectorFloat -  floatEltSqrt = sqrtVectorFloat -  floatEltSin = sinVectorFloat -  floatEltCos = cosVectorFloat -  floatEltTan = tanVectorFloat -  floatEltAsin = asinVectorFloat -  floatEltAcos = acosVectorFloat -  floatEltAtan = atanVectorFloat -  floatEltSinh = sinhVectorFloat -  floatEltCosh = coshVectorFloat -  floatEltTanh = tanhVectorFloat -  floatEltAsinh = asinhVectorFloat -  floatEltAcosh = acoshVectorFloat -  floatEltAtanh = atanhVectorFloat -  floatEltLog1p = log1pVectorFloat -  floatEltExpm1 = expm1VectorFloat -  floatEltLog1pexp = log1pexpVectorFloat -  floatEltLog1mexp = log1mexpVectorFloat -  floatEltAtan2 = atan2VectorFloat - -instance FloatElt Double where -  floatEltDiv = divVectorDouble -  floatEltPow = powVectorDouble -  floatEltLogbase = logbaseVectorDouble -  floatEltRecip = recipVectorDouble -  floatEltExp = expVectorDouble -  floatEltLog = logVectorDouble -  floatEltSqrt = sqrtVectorDouble -  floatEltSin = sinVectorDouble -  floatEltCos = cosVectorDouble -  floatEltTan = tanVectorDouble -  floatEltAsin = asinVectorDouble -  
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 71bdc1f..204c1d8 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -34,6 +34,7 @@ import Data.Array.Mixed.Lemmas
 import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Shape
 import Data.Array.Mixed.Types
+import Data.Array.Strided.Arith
 
 
 type XArray :: [Maybe Nat] -> Type -> Type
@@ -240,7 +241,7 @@ transpose2 ssh1 ssh2 (XArray arr)
 
 sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
 sumFull _ (XArray arr) =
   S.unScalar $
-    numEltSum1Inner (SNat @0) $
+    liftO1 (numEltSum1Inner (SNat @0)) $
       S.fromVector [product (S.shapeL arr)] $
         S.toVector arr
@@ -256,7 +257,7 @@ sumInner ssh ssh' arr
         go (XArray arr')
           | Refl <- lemRankApp ssh ssh'F
           , let sn = listxRank (let StaticShX l = ssh in l)
-          = XArray (numEltSum1Inner sn arr')
+          = XArray (liftO1 (numEltSum1Inner sn) arr')
 
     in go $
        transpose2 ssh'F ssh $
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index db13da4..9869cba 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -103,7 +103,6 @@ module Data.Array.Nested (
 
 import Prelude hiding (mappend, mconcat)
 
-import Data.Array.Mixed.Internal.Arith
 import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Shape
 import Data.Array.Mixed.Types
@@ -112,6 +111,7 @@ import Data.Array.Nested.Internal.Mixed
 import Data.Array.Nested.Internal.Ranked
 import Data.Array.Nested.Internal.Shape
 import Data.Array.Nested.Internal.Shaped
+import Data.Array.Strided.Arith
 import Foreign.Storable
 import GHC.TypeLits
 
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 80d581e..eb452dd 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape
 import Data.Array.Mixed.Types
 import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Lemmas
+import Data.Array.Strided.Arith
 
 -- TODO:
 --   sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
@@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_
   | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
 
 instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
-  (+) = mliftNumElt2 numEltAdd
-  (-) = mliftNumElt2 numEltSub
-  (*) = mliftNumElt2 numEltMul
-  negate = mliftNumElt1 numEltNeg
-  abs = mliftNumElt1 numEltAbs
-  signum = mliftNumElt1 numEltSignum
+  (+) = mliftNumElt2 (liftO2 . numEltAdd)
+  (-) = mliftNumElt2 (liftO2 . numEltSub)
+  (*) = mliftNumElt2 (liftO2 . numEltMul)
+  negate = mliftNumElt1 (liftO1 . numEltNeg)
+  abs = mliftNumElt1 (liftO1 . numEltAbs)
+  signum = mliftNumElt1 (liftO1 . numEltSignum)
 
   -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
   fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
 
 instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
   fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
-  recip = mliftNumElt1 floatEltRecip
-  (/) = mliftNumElt2 floatEltDiv
+  recip = mliftNumElt1 (liftO1 . floatEltRecip)
+  (/) = mliftNumElt2 (liftO2 . floatEltDiv)
 
 instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
   pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
-  exp = mliftNumElt1 floatEltExp
-  log = mliftNumElt1 floatEltLog
-  sqrt = mliftNumElt1 floatEltSqrt
+  exp = mliftNumElt1 (liftO1 . floatEltExp)
+  log = mliftNumElt1 (liftO1 . floatEltLog)
+  sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
 
-  (**) = mliftNumElt2 floatEltPow
-  logBase = mliftNumElt2 floatEltLogbase
+  (**) = mliftNumElt2 (liftO2 . floatEltPow)
+  logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
 
-  sin = mliftNumElt1 floatEltSin
-  cos = mliftNumElt1 floatEltCos
-  tan = mliftNumElt1 floatEltTan
-  asin = mliftNumElt1 floatEltAsin
-  acos = mliftNumElt1 floatEltAcos
-  atan = mliftNumElt1 floatEltAtan
-  sinh = mliftNumElt1 floatEltSinh
-  cosh = mliftNumElt1 floatEltCosh
-  tanh = mliftNumElt1 floatEltTanh
-  asinh = mliftNumElt1 floatEltAsinh
-  acosh = mliftNumElt1 floatEltAcosh
-  atanh = mliftNumElt1 floatEltAtanh
-  log1p = mliftNumElt1 floatEltLog1p
-  expm1 = mliftNumElt1 floatEltExpm1
-  log1pexp = mliftNumElt1 floatEltLog1pexp
-  log1mexp = mliftNumElt1 floatEltLog1mexp
+  sin = mliftNumElt1 (liftO1 . floatEltSin)
+  cos = mliftNumElt1 (liftO1 . floatEltCos)
+  tan = mliftNumElt1 (liftO1 . floatEltTan)
+  asin = mliftNumElt1 (liftO1 . floatEltAsin)
+  acos = mliftNumElt1 (liftO1 . floatEltAcos)
+  atan = mliftNumElt1 (liftO1 . floatEltAtan)
+  sinh = mliftNumElt1 (liftO1 . floatEltSinh)
+  cosh = mliftNumElt1 (liftO1 . floatEltCosh)
+  tanh = mliftNumElt1 (liftO1 . floatEltTanh)
+  asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
+  acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
+  atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
+  log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
+  expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
+  log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
+  log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
 
 mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-mquotArray = mliftNumElt2 intEltQuot
-mremArray = mliftNumElt2 intEltRem
+mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
+mremArray = mliftNumElt2 (liftO2 . intEltRem)
 
 matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-matan2Array = mliftNumElt2 floatEltAtan2
+matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
 
 -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
@@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
 -- | Throws if the array is empty.
 mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
 mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
-  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr)
+  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
 
 -- | Throws if the array is empty.
 mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
 mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
-  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr)
+  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
 
 mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
            => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
@@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
       _ :$% _
         | sh1 == sh2
        , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
-            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
         | otherwise -> error "mdot1Inner: Unequal shapes"
       ZSX -> error "unreachable"
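Note on the instance rewrites above: mliftNumElt1/mliftNumElt2 still expect operations on the underlying orthotope arrays, so each kernel, now defined on strided arrays, is precomposed with the adapter. A type sketch under the assumed new kernel signature (illustration only, not code in the patch):

    -- assumed: numEltAdd :: NumElt a => SNat n -> AS.Array n a -> AS.Array n a -> AS.Array n a
    -- from this patch:
    --   liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
    --          -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
    -- hence:
    --   liftO2 . numEltAdd :: NumElt a => SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
    -- which is the RS.Array-based type the instances passed to mliftNumElt2 before,
    -- so only the instance bodies change, not the lifting combinators.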
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 1c6b789..0a165bc 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -41,13 +41,13 @@ import GHC.TypeNats qualified as TN
 
 import Data.Array.Mixed.XArray (XArray(..))
 import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Internal.Arith
 import Data.Array.Mixed.Lemmas
 import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Shape
 import Data.Array.Mixed.Types
 import Data.Array.Nested.Internal.Mixed
 import Data.Array.Nested.Internal.Shape
+import Data.Array.Strided.Arith
 
 
 -- | A rank-typed array: the number of dimensions of the array (its /rank/) is
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 35628db..d7a8ece 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -41,7 +41,6 @@ import GHC.TypeLits
 
 import Data.Array.Mixed.XArray (XArray)
 import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Internal.Arith
 import Data.Array.Mixed.Lemmas
 import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Shape
@@ -49,6 +48,7 @@ import Data.Array.Mixed.Types
 import Data.Array.Nested.Internal.Lemmas
 import Data.Array.Nested.Internal.Mixed
 import Data.Array.Nested.Internal.Shape
+import Data.Array.Strided.Arith
 
 
 -- | A shape-typed array: the full shape of the array (the sizes of its
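Note for downstream code: as in the two files above, the element-type classes (NumElt, IntElt, FloatElt) now come from the relocated arithmetic package. A migration sketch, assuming Data.Array.Strided.Arith re-exports the classes the way these modules consume them:

    -- before (import dropped throughout this patch):
    --   import Data.Array.Mixed.Internal.Arith
    -- after:
    import Data.Array.Strided.Arith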
