diff options
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 0b9b8eb..579c0da 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -34,10 +34,10 @@ import Data.Array.Mixed.Internal.Arith.Lists -- TODO: test all the cases of this thing with various input strides -liftVEltwise1 :: Storable a +liftVEltwise1 :: (Storable a, Storable b) => SNat n - -> (VS.Vector a -> VS.Vector a) - -> RS.Array n a -> RS.Array n a + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | Just (blockOff, blockSz) <- stridesDense sh offset strides = let vec' = f (VS.slice blockOff blockSz vec) @@ -45,15 +45,15 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | otherwise = RS.fromVector sh (f (RS.toVector arr)) -- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: Storable a +liftVEltwise2 :: (Storable a, Storable b, Storable c) => SNat n - -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -> RS.Array n a + -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) + -> RS.Array n a -> RS.Array n b -> RS.Array n c liftVEltwise2 SNat f arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs + | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty)) | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) |