aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal/Arith.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-17 12:04:09 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-17 12:04:09 +0200
commit2ca90987058d14c79cd983ab14ee57949bae2871 (patch)
tree1f81d469fbf9eba174fc0d7c40b5626123830cb8 /src/Data/Array/Mixed/Internal/Arith.hs
parent63b60c06674127e96cebfc3f1e8710f31df379d7 (diff)
Generalise some of the lifting functions to type-changing
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 0b9b8eb..579c0da 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -34,10 +34,10 @@ import Data.Array.Mixed.Internal.Arith.Lists
-- TODO: test all the cases of this thing with various input strides
-liftVEltwise1 :: Storable a
+liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
- -> (VS.Vector a -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a
+ -> (VS.Vector a -> VS.Vector b)
+ -> RS.Array n a -> RS.Array n b
liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| Just (blockOff, blockSz) <- stridesDense sh offset strides =
let vec' = f (VS.slice blockOff blockSz vec)
@@ -45,15 +45,15 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: Storable a
+liftVEltwise2 :: (Storable a, Storable b, Storable c)
=> SNat n
- -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
+ -> RS.Array n a -> RS.Array n b -> RS.Array n c
liftVEltwise2 SNat f
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
+ | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
| otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
(Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))