aboutsummaryrefslogtreecommitdiff
path: root/ops/Data/Array/Strided
diff options
context:
space:
mode:
Diffstat (limited to 'ops/Data/Array/Strided')
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs70
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
-rw-r--r--ops/Data/Array/Strided/Array.hs7
3 files changed, 61 insertions, 20 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 313d72f..d94fc65 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
@@ -9,7 +10,6 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Strided.Arith.Internal where
@@ -21,20 +21,20 @@ import Data.Int
import Data.List (sort, zip4)
import Data.Proxy
import Data.Type.Equality
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
-import qualified GHC.TypeNats as TypeNats
import GHC.TypeLits
+import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
-import Data.Array.Strided.Array
-import Data.Array.Strided.Arith.Internal.Lists
import Data.Array.Strided.Arith.Internal.Foreign
+import Data.Array.Strided.Arith.Internal.Lists
+import Data.Array.Strided.Array
-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
@@ -49,11 +49,11 @@ data Dict c where
debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
debugShow (Array sh strides offset vec) =
- "Array @" ++ (show (natVal (Proxy @n))) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">"
+ "Array @" ++ show (natVal (Proxy @n)) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">"
-- TODO: test all the cases of this thing with various input strides
-liftOpEltwise1 :: (Storable a, Storable b)
+liftOpEltwise1 :: Storable a
=> SNat n
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
@@ -62,7 +62,7 @@ liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
| Just (blockOff, blockSz) <- stridesDense sh offset strides =
if blockSz == 0
then Array sh (map (const 0) strides) 0 VS.empty
- else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec)
+ else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [blockSz] [1] blockOff vec)
in Array sh strides (offset - blockOff) resvec
| otherwise = wrapUnary sn ptrconv cf_strided arr
@@ -174,8 +174,8 @@ unreplicateStrides (Array sh strides offset vec) =
reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
reinsertZeros [] [] = []
- reinsertZeros (False : _) [] = error $ "unreplicateStrides: Internal error: reply strides too short"
- reinsertZeros [] (_:_) = error $ "unreplicateStrides: Internal error: reply strides too long"
+ reinsertZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error "unreplicateStrides: Internal error: reply strides too long"
unrepSize = product [n | (n, True) <- zip sh replDims]
@@ -214,7 +214,7 @@ simplifyArray array k
if | sh' /= init (arrShape array') ->
error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
| last (arrStrides array) == 0 ->
- error $ "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
+ error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
| otherwise ->
arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
@@ -253,8 +253,8 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
reinsertZeros [] [] = []
- reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short"
- reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long"
+ reinsertZeros (False : _) [] = error "simplifyArray2: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error "simplifyArray2: Internal error: reply strides too long"
, let unrepSize = product [n | (n, True) <- zip sh replDims]
@@ -272,7 +272,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
if | sh' /= init shF ->
error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
| last replDims ->
- error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
+ error "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
| otherwise ->
arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec'))
@@ -673,7 +673,7 @@ intWidBranchRedFull fsc fred32 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
| otherwise = error "Unsupported Int width"
-intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i)
=> -- int32
(forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
-- int64
@@ -714,6 +714,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -830,6 +860,14 @@ class NumElt a => IntElt a where
intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+instance IntElt Int8 where
+ intEltQuot = quotVectorInt8
+ intEltRem = remVectorInt8
+
+instance IntElt Int16 where
+ intEltQuot = quotVectorInt16
+ intEltRem = remVectorInt16
+
instance IntElt Int32 where
intEltQuot = quotVectorInt32
intEltRem = remVectorInt32
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs
index df455c7..9280fe0 100644
--- a/ops/Data/Array/Strided/Array.hs
+++ b/ops/Data/Array/Strided/Array.hs
@@ -1,12 +1,13 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Strided.Array where
-import qualified Data.List.NonEmpty as NE
+import Data.List.NonEmpty qualified as NE
import Data.Proxy
-import qualified Data.Vector.Storable as VS
+import Data.Vector.Storable qualified as VS
import Foreign.Storable
import GHC.TypeLits
@@ -30,7 +31,7 @@ arrayFromVector sh vec
shsize = product sh
strides = NE.tail (NE.scanr (*) 1 sh)
-arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a
+arrayFromConstant :: Storable a => [Int] -> a -> Array n a
arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x)
arrayRevDims :: [Bool] -> Array n a -> Array n a