diff options
Diffstat (limited to 'ops/Data/Array/Strided')
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 70 | ||||
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal/Lists.hs | 4 | ||||
| -rw-r--r-- | ops/Data/Array/Strided/Array.hs | 7 |
3 files changed, 61 insertions, 20 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 313d72f..d94fc65 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} @@ -9,7 +10,6 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Strided.Arith.Internal where @@ -21,20 +21,20 @@ import Data.Int import Data.List (sort, zip4) import Data.Proxy import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types import Foreign.Ptr import Foreign.Storable -import qualified GHC.TypeNats as TypeNats import GHC.TypeLits +import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH import System.IO (hFlush, stdout) import System.IO.Unsafe -import Data.Array.Strided.Array -import Data.Array.Strided.Arith.Internal.Lists import Data.Array.Strided.Arith.Internal.Foreign +import Data.Array.Strided.Arith.Internal.Lists +import Data.Array.Strided.Array -- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition @@ -49,11 +49,11 @@ data Dict c where debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String debugShow (Array sh strides offset vec) = - "Array @" ++ (show (natVal (Proxy @n))) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">" + "Array @" ++ show (natVal (Proxy @n)) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">" -- TODO: test all the cases of this thing with various input strides -liftOpEltwise1 :: (Storable a, Storable b) +liftOpEltwise1 :: Storable a => SNat n -> (Ptr a -> Ptr b) -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) @@ -62,7 +62,7 @@ liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec) | Just (blockOff, blockSz) <- stridesDense sh offset strides = if blockSz == 0 then Array sh (map (const 0) strides) 0 VS.empty - else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec) + else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [blockSz] [1] blockOff vec) in Array sh strides (offset - blockOff) resvec | otherwise = wrapUnary sn ptrconv cf_strided arr @@ -174,8 +174,8 @@ unreplicateStrides (Array sh strides offset vec) = reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' reinsertZeros [] [] = [] - reinsertZeros (False : _) [] = error $ "unreplicateStrides: Internal error: reply strides too short" - reinsertZeros [] (_:_) = error $ "unreplicateStrides: Internal error: reply strides too long" + reinsertZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short" + reinsertZeros [] (_:_) = error "unreplicateStrides: Internal error: reply strides too long" unrepSize = product [n | (n, True) <- zip sh replDims] @@ -214,7 +214,7 @@ simplifyArray array k if | sh' /= init (arrShape array') -> error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")" | last (arrStrides array) == 0 -> - error $ "simplifyArray: Internal error: reduction reply handler used while inner stride was 0" + error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0" | otherwise -> arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec')) @@ -253,8 +253,8 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' reinsertZeros [] [] = [] - reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short" - reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long" + reinsertZeros (False : _) [] = error "simplifyArray2: Internal error: reply strides too short" + reinsertZeros [] (_:_) = error "simplifyArray2: Internal error: reply strides too long" , let unrepSize = product [n | (n, True) <- zip sh replDims] @@ -272,7 +272,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k if | sh' /= init shF -> error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" | last replDims -> - error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated" + error "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated" | otherwise -> arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec')) @@ -673,7 +673,7 @@ intWidBranchRedFull fsc fred32 fred64 sn | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 | otherwise = error "Unsupported Int width" -intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i) +intWidBranchExtr :: forall i n. (FiniteBits i, Storable i) => -- int32 (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel -- int64 @@ -714,6 +714,36 @@ class NumElt a where numEltMaxIndex :: SNat n -> Array n a -> [Int] numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a +instance NumElt Int8 where + numEltAdd = addVectorInt8 + numEltSub = subVectorInt8 + numEltMul = mulVectorInt8 + numEltNeg = negVectorInt8 + numEltAbs = absVectorInt8 + numEltSignum = signumVectorInt8 + numEltSum1Inner = sum1VectorInt8 + numEltProduct1Inner = product1VectorInt8 + numEltSumFull = sumFullVectorInt8 + numEltProductFull = productFullVectorInt8 + numEltMinIndex _ = minindexVectorInt8 + numEltMaxIndex _ = maxindexVectorInt8 + numEltDotprodInner = dotprodinnerVectorInt8 + +instance NumElt Int16 where + numEltAdd = addVectorInt16 + numEltSub = subVectorInt16 + numEltMul = mulVectorInt16 + numEltNeg = negVectorInt16 + numEltAbs = absVectorInt16 + numEltSignum = signumVectorInt16 + numEltSum1Inner = sum1VectorInt16 + numEltProduct1Inner = product1VectorInt16 + numEltSumFull = sumFullVectorInt16 + numEltProductFull = productFullVectorInt16 + numEltMinIndex _ = minindexVectorInt16 + numEltMaxIndex _ = maxindexVectorInt16 + numEltDotprodInner = dotprodinnerVectorInt16 + instance NumElt Int32 where numEltAdd = addVectorInt32 numEltSub = subVectorInt32 @@ -830,6 +860,14 @@ class NumElt a => IntElt a where intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a intEltRem :: SNat n -> Array n a -> Array n a -> Array n a +instance IntElt Int8 where + intEltQuot = quotVectorInt8 + intEltRem = remVectorInt8 + +instance IntElt Int16 where + intEltQuot = quotVectorInt16 + intEltRem = remVectorInt16 + instance IntElt Int32 where intEltQuot = quotVectorInt32 intEltRem = remVectorInt32 diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs index 910a77c..27204d2 100644 --- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -16,7 +16,9 @@ data ArithType = ArithType intTypesList :: [ArithType] intTypesList = - [ArithType ''Int32 "i32" + [ArithType ''Int8 "i8" + ,ArithType ''Int16 "i16" + ,ArithType ''Int32 "i32" ,ArithType ''Int64 "i64" ] diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs index df455c7..9280fe0 100644 --- a/ops/Data/Array/Strided/Array.hs +++ b/ops/Data/Array/Strided/Array.hs @@ -1,12 +1,13 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} module Data.Array.Strided.Array where -import qualified Data.List.NonEmpty as NE +import Data.List.NonEmpty qualified as NE import Data.Proxy -import qualified Data.Vector.Storable as VS +import Data.Vector.Storable qualified as VS import Foreign.Storable import GHC.TypeLits @@ -30,7 +31,7 @@ arrayFromVector sh vec shsize = product sh strides = NE.tail (NE.scanr (*) 1 sh) -arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a +arrayFromConstant :: Storable a => [Int] -> a -> Array n a arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x) arrayRevDims :: [Bool] -> Array n a -> Array n a |
