aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs23
1 files changed, 21 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 07d5d8a..042c9d0 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -6,9 +6,10 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+-- {-# OPTIONS_GHC -ddump-simpl -ddump-to-file -ddump-file-prefix=Arith #-}
module Data.Array.Nested.Internal.Arith where
-import Control.Monad (forM, guard)
+import Control.Monad (forM, forM_, guard)
import qualified Data.Array.Internal as OI
import qualified Data.Array.Internal.RankedG as RG
import qualified Data.Array.Internal.RankedS as RS
@@ -309,8 +310,26 @@ instance NumElt Float where
numEltSum1Inner = sum1VectorFloat
numEltProduct1Inner = product1VectorFloat
+hsaddDoubleSV :: Double -> VS.Vector Double -> VS.Vector Double
+hsaddDoubleSV = error "unimplemented"
+
+{-# NOINLINE hsaddDoubleVV #-}
+hsaddDoubleVV :: VS.Vector Double -> VS.Vector Double -> VS.Vector Double
+-- hsaddDoubleVV = VS.zipWith (+)
+hsaddDoubleVV v1 v2 = unsafePerformIO $ do
+ let n = min (VS.length v1) (VS.length v2)
+ dest <- VSM.unsafeNew n
+ forM_ [0 .. n - 1] $ \i -> do
+ VSM.write dest i (v1 VS.! i + v2 VS.! i)
+ VS.unsafeFreeze dest
+
instance NumElt Double where
- numEltAdd = addVectorDouble
+ numEltAdd = \sn -> liftVEltwise2 sn $ \cases
+ (Left x) (Left y) -> VS.singleton (x + y)
+ (Left x) (Right vy) -> hsaddDoubleSV x vy
+ (Right vx) (Left y) -> hsaddDoubleSV y vx
+ (Right vx) (Right vy) -> hsaddDoubleVV vx vy
+ -- numEltAdd = addVectorDouble
numEltSub = subVectorDouble
numEltMul = mulVectorDouble
numEltNeg = negVectorDouble