diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-17 12:04:09 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-17 12:04:09 +0200 | 
| commit | 2ca90987058d14c79cd983ab14ee57949bae2871 (patch) | |
| tree | 1f81d469fbf9eba174fc0d7c40b5626123830cb8 /src/Data/Array/Mixed | |
| parent | 63b60c06674127e96cebfc3f1e8710f31df379d7 (diff) | |
Generalise some of the lifting functions to type-changing
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | 
1 files changed, 7 insertions, 7 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 0b9b8eb..579c0da 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -34,10 +34,10 @@ import Data.Array.Mixed.Internal.Arith.Lists  -- TODO: test all the cases of this thing with various input strides -liftVEltwise1 :: Storable a +liftVEltwise1 :: (Storable a, Storable b)                => SNat n -              -> (VS.Vector a -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a +              -> (VS.Vector a -> VS.Vector b) +              -> RS.Array n a -> RS.Array n b  liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | Just (blockOff, blockSz) <- stridesDense sh offset strides =        let vec' = f (VS.slice blockOff blockSz vec) @@ -45,15 +45,15 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | otherwise = RS.fromVector sh (f (RS.toVector arr))  -- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: Storable a +liftVEltwise2 :: (Storable a, Storable b, Storable c)                => SNat n -              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a -> RS.Array n a +              -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) +              -> RS.Array n a -> RS.Array n b -> RS.Array n c  liftVEltwise2 SNat f      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 -  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs +  | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))    | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of        (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars          let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) | 
