diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 00:30:25 +0100 |
commit | c14017f4bc28951be7e298d01769b5b49384a7c3 (patch) | |
tree | dd7ea8e90b28e37ac46251d11be2eb6c0ffc699b /src/Data/Array/Mixed/Internal/Arith.hs | |
parent | b0fae0894f4440c6cd9cd74b5a3515baa8bd8c35 (diff) |
arith: Unary int ops on strided arrays without normalisation
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 734c7cd..123a4b5 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | otherwise = RS.fromVector sh (f (RS.toVector arr)) -- TODO: test all the cases of this thing with various input strides +{-# NOINLINE liftOpEltwise1 #-} +liftOpEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ()) + -> RS.Array n a -> RS.Array n b +liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f_vec (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = unsafePerformIO $ do + outv <- VSM.unsafeNew (product sh) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> + VS.unsafeWith vec $ \pv -> + cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv + RS.fromVector sh <$> VS.unsafeFreeze outv + +-- TODO: test all the cases of this thing with various input strides liftVEltwise2 :: (Storable a, Storable b, Storable c) => SNat n -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) @@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) + c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |] return $ FunD name [Clause [] (NormalB body) []]]) $(fmap concat . forM floatTypesList $ \arithtype -> do |