diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 12 | 
4 files changed, 25 insertions, 25 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 0b9b8eb..579c0da 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -34,10 +34,10 @@ import Data.Array.Mixed.Internal.Arith.Lists  -- TODO: test all the cases of this thing with various input strides -liftVEltwise1 :: Storable a +liftVEltwise1 :: (Storable a, Storable b)                => SNat n -              -> (VS.Vector a -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a +              -> (VS.Vector a -> VS.Vector b) +              -> RS.Array n a -> RS.Array n b  liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | Just (blockOff, blockSz) <- stridesDense sh offset strides =        let vec' = f (VS.slice blockOff blockSz vec) @@ -45,15 +45,15 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | otherwise = RS.fromVector sh (f (RS.toVector arr))  -- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: Storable a +liftVEltwise2 :: (Storable a, Storable b, Storable c)                => SNat n -              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a -> RS.Array n a +              -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) +              -> RS.Array n a -> RS.Array n b -> RS.Array n c  liftVEltwise2 SNat f      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 -  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs +  | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))    | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of        (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars          let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index ddc075c..647ea82 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -871,13 +871,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr  mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a  mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP -mliftPrim :: PrimElt a -          => (a -> a) -          -> Mixed sh a -> Mixed sh a +mliftPrim :: (PrimElt a, PrimElt b) +          => (a -> b) +          -> Mixed sh a -> Mixed sh b  mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) -mliftPrim2 :: PrimElt a -           => (a -> a -> a) -           -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) +           => (a -> b -> c) +           -> Mixed sh a -> Mixed sh b -> Mixed sh c  mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =    fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index e59ac0c..bd37e7a 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -180,14 +180,14 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where      = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) -arithPromoteRanked :: forall n a. PrimElt a -                   => (forall sh. Mixed sh a -> Mixed sh a) -                   -> Ranked n a -> Ranked n a +arithPromoteRanked :: forall n a b. +                      (forall sh. Mixed sh a -> Mixed sh b) +                   -> Ranked n a -> Ranked n b  arithPromoteRanked = coerce -arithPromoteRanked2 :: forall n a. PrimElt a -                    => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) -                    -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 :: forall n a b c. +                       (forall sh. Mixed sh a -> Mixed sh b -> Mixed sh c) +                    -> Ranked n a -> Ranked n b -> Ranked n c  arithPromoteRanked2 = coerce  instance (NumElt a, PrimElt a, Num a) => Num (Ranked n a) where diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 2c24e6d..f50ed28 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -178,14 +178,14 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where      = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) -arithPromoteShaped :: forall sh a. PrimElt a -                   => (forall shx. Mixed shx a -> Mixed shx a) -                   -> Shaped sh a -> Shaped sh a +arithPromoteShaped :: forall sh a b. +                      (forall shx. Mixed shx a -> Mixed shx b) +                   -> Shaped sh a -> Shaped sh b  arithPromoteShaped = coerce -arithPromoteShaped2 :: forall sh a. PrimElt a -                    => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) -                    -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 :: forall sh a b c. +                       (forall shx. Mixed shx a -> Mixed shx b -> Mixed shx c) +                    -> Shaped sh a -> Shaped sh b -> Shaped sh c  arithPromoteShaped2 = coerce  instance (NumElt a, PrimElt a, Num a) => Num (Shaped sh a) where | 
