aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
author Tom Smeding <tom@tomsmeding.com> 2025-02-16 00:30:25 +0100
committer Tom Smeding <tom@tomsmeding.com> 2025-02-16 00:30:25 +0100
commit c14017f4bc28951be7e298d01769b5b49384a7c3 (patch)
tree dd7ea8e90b28e37ac46251d11be2eb6c0ffc699b /src/Data/Array
parent b0fae0894f4440c6cd9cd74b5a3515baa8bd8c35 (diff)
arith: Unary int ops on strided arrays without normalisation
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- src/Data/Array/Mixed/Internal/Arith.hs 23
-rw-r--r-- src/Data/Array/Mixed/Internal/Arith/Foreign.hs 1
2 files changed, 23 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 734c7cd..123a4b5 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-- TODO: test all the cases of this thing with various input strides
+{-# NOINLINE liftOpEltwise1 #-}
+-- | Apply a unary elementwise operation to an array without first
+-- normalising its strides.
+--
+-- * When 'stridesDense' yields a block @(blockOff, blockSz)@, @f_vec@ is run
+--   on that slice of the backing vector; the result keeps the original
+--   shape and strides, with the offset rebased to @offset - blockOff@.
+-- * Otherwise @cf_strided@ is called with the rank, a freshly allocated
+--   output buffer of @product sh@ elements, the shape, the strides and a
+--   pointer into the backing vector advanced by @offset@ elements. The
+--   output buffer is then wrapped with 'RS.fromVector'.
+--
+-- The 'unsafePerformIO' only allocates and fills a private buffer that is
+-- frozen before it escapes. NOTE(review): this relies on @cf_strided@ being
+-- a deterministic function of its inputs that writes only to the output
+-- buffer -- confirm for every kernel passed in.
+liftOpEltwise1 :: (Storable a, Storable b)
+               => SNat n
+               -> (VS.Vector a -> VS.Vector b)
+               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ())
+               -> RS.Array n a -> RS.Array n b
+liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
+  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+      let vec' = f_vec (VS.slice blockOff blockSz vec)
+      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+  | otherwise = unsafePerformIO $ do
+      outv <- VSM.unsafeNew (product sh)
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
+            -- The kernel takes no offset argument, so the input pointer must
+            -- already be advanced by 'offset' instead of pointing at the
+            -- start of the backing vector.
+            VS.unsafeWith (VS.drop offset vec) $ \pv ->
+              cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv
+      RS.fromVector sh <$> VS.unsafeFreeze outv
+
+-- TODO: test all the cases of this thing with various input strides
liftVEltwise2 :: (Storable a, Storable b, Storable c)
=> SNat n
-> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
@@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM floatTypesList $ \arithtype -> do
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ade7ce1..22c5b53 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -16,6 +16,7 @@ $(do
,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
,("unary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])