diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 | 
| commit | c14017f4bc28951be7e298d01769b5b49384a7c3 (patch) | |
| tree | dd7ea8e90b28e37ac46251d11be2eb6c0ffc699b /src/Data/Array | |
| parent | b0fae0894f4440c6cd9cd74b5a3515baa8bd8c35 (diff) | |
arith: Unary int ops on strided arrays without normalisation
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 | 
2 files changed, 23 insertions, 1 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 734c7cd..123a4b5 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | otherwise = RS.fromVector sh (f (RS.toVector arr))  -- TODO: test all the cases of this thing with various input strides +{-# NOINLINE liftOpEltwise1 #-} +liftOpEltwise1 :: (Storable a, Storable b) +               => SNat n +               -> (VS.Vector a -> VS.Vector b) +               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ()) +               -> RS.Array n a -> RS.Array n b +liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f_vec (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +  | otherwise = unsafePerformIO $ do +      outv <- VSM.unsafeNew (product sh) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> +          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> +            VS.unsafeWith vec $ \pv -> +              cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv +      RS.fromVector sh <$> VS.unsafeFreeze outv + +-- TODO: test all the cases of this thing with various input strides  liftVEltwise2 :: (Storable a, Storable b, Storable c)                => SNat n                -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) @@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))            c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) +          c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM floatTypesList $ \arithtype -> do diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ade7ce1..22c5b53 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -16,6 +16,7 @@ $(do          ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])          ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])          ,("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("unary_" ++ tyn ++ "_strided",   [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])          ,("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) | 
