aboutsummaryrefslogtreecommitdiff
path: root/ops/Data
diff options
context:
space:
mode:
Diffstat (limited to 'ops/Data')
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs69
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
2 files changed, 57 insertions, 16 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 5802573..2eb0666 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -27,7 +27,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
+import GHC.TypeNats qualified as TN
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
@@ -42,7 +42,7 @@ import Data.Array.Strided.Array
-- TODO: move this to a utilities module
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
data Dict c where
Dict :: c => Dict c
@@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) =
unrepSize = product [n | (n, True) <- zip sh replDims]
- in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
simplifyArray :: Array n a
@@ -200,7 +200,7 @@ simplifyArray :: Array n a
-> r)
-> r
simplifyArray array k
- | let revDims = map (<0) (arrStrides array)
+ | let revDims = map (< 0) (arrStrides array)
, Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
= k array'
unrepSize
@@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let unrepSize = product [n | (n, True) <- zip sh replDims]
- = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
k @lenshF
(Array shF strides1F offset1 vec1)
(Array shF strides2F offset2 vec2)
@@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
VS.unsafeWith vec' $ \pv ->
let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
- TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -396,6 +396,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
Nothing -> error "impossible"
-- TODO: test handling of negative strides
+-- TODO: simplify away normalised dimensions
-- | Reduce full array
{-# NOINLINE vectorRedFullOp #-}
vectorRedFullOp :: forall a b n. (Num a, Storable a)
@@ -490,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
- TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -714,6 +715,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -830,6 +861,14 @@ class NumElt a => IntElt a where
intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+instance IntElt Int8 where
+ intEltQuot = quotVectorInt8
+ intEltRem = remVectorInt8
+
+instance IntElt Int16 where
+ intEltQuot = quotVectorInt16
+ intEltRem = remVectorInt16
+
instance IntElt Int32 where
intEltQuot = quotVectorInt32
intEltRem = remVectorInt32
@@ -840,19 +879,19 @@ instance IntElt Int64 where
instance IntElt Int where
intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
instance IntElt CInt where
intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]