diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 |
commit | c14017f4bc28951be7e298d01769b5b49384a7c3 (patch) | |
tree | dd7ea8e90b28e37ac46251d11be2eb6c0ffc699b /src/Data/Array | |
parent | b0fae0894f4440c6cd9cd74b5a3515baa8bd8c35 (diff) |
arith: Unary int ops on strided arrays without normalisation
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 23 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 |
2 files changed, 23 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 734c7cd..123a4b5 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | otherwise = RS.fromVector sh (f (RS.toVector arr)) -- TODO: test all the cases of this thing with various input strides +{-# NOINLINE liftOpEltwise1 #-} +liftOpEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ()) + -> RS.Array n a -> RS.Array n b +liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f_vec (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = unsafePerformIO $ do + outv <- VSM.unsafeNew (product sh) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> + VS.unsafeWith vec $ \pv -> + cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv + RS.fromVector sh <$> VS.unsafeFreeze outv + +-- TODO: test all the cases of this thing with various input strides liftVEltwise2 :: (Storable a, Storable b, Storable c) => SNat n -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) @@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) + c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |] return $ FunD name [Clause [] (NormalB body) []]]) $(fmap concat . forM floatTypesList $ \arithtype -> do diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ade7ce1..22c5b53 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -16,6 +16,7 @@ $(do ,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) ,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) ,("unary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) |