aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs14
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs12
4 files changed, 25 insertions, 25 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 0b9b8eb..579c0da 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -34,10 +34,10 @@ import Data.Array.Mixed.Internal.Arith.Lists
-- TODO: test all the cases of this thing with various input strides
-liftVEltwise1 :: Storable a
+liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
- -> (VS.Vector a -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a
+ -> (VS.Vector a -> VS.Vector b)
+ -> RS.Array n a -> RS.Array n b
liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| Just (blockOff, blockSz) <- stridesDense sh offset strides =
let vec' = f (VS.slice blockOff blockSz vec)
@@ -45,15 +45,15 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: Storable a
+liftVEltwise2 :: (Storable a, Storable b, Storable c)
=> SNat n
- -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
+ -> RS.Array n a -> RS.Array n b -> RS.Array n c
liftVEltwise2 SNat f
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
+ | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
| otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
(Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index ddc075c..647ea82 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -871,13 +871,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
-mliftPrim :: PrimElt a
- => (a -> a)
- -> Mixed sh a -> Mixed sh a
+mliftPrim :: (PrimElt a, PrimElt b)
+ => (a -> b)
+ -> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
-mliftPrim2 :: PrimElt a
- => (a -> a -> a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
+ => (a -> b -> c)
+ -> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index e59ac0c..bd37e7a 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -180,14 +180,14 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-arithPromoteRanked :: forall n a. PrimElt a
- => (forall sh. Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a
+arithPromoteRanked :: forall n a b.
+ (forall sh. Mixed sh a -> Mixed sh b)
+ -> Ranked n a -> Ranked n b
arithPromoteRanked = coerce
-arithPromoteRanked2 :: forall n a. PrimElt a
- => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 :: forall n a b c.
+ (forall sh. Mixed sh a -> Mixed sh b -> Mixed sh c)
+ -> Ranked n a -> Ranked n b -> Ranked n c
arithPromoteRanked2 = coerce
instance (NumElt a, PrimElt a, Num a) => Num (Ranked n a) where
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 2c24e6d..f50ed28 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -178,14 +178,14 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-arithPromoteShaped :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a
+arithPromoteShaped :: forall sh a b.
+ (forall shx. Mixed shx a -> Mixed shx b)
+ -> Shaped sh a -> Shaped sh b
arithPromoteShaped = coerce
-arithPromoteShaped2 :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 :: forall sh a b c.
+ (forall shx. Mixed shx a -> Mixed shx b -> Mixed shx c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
arithPromoteShaped2 = coerce
instance (NumElt a, PrimElt a, Num a) => Num (Shaped sh a) where