diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 23:49:56 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-16 23:50:07 +0100 | 
| commit | 71908c23307952fac26a4e24066e064d9cbb71c0 (patch) | |
| tree | 6bcb91ff66d0c6623c36ed965514d06e76585a14 /src/Data/Array/Mixed | |
| parent | c14017f4bc28951be7e298d01769b5b49384a7c3 (diff) | |
arith: Only strided unary int ops
This should have negligible overhead and will save a whole bunch of C
code duplication when the FUnops are also converted to strided form.
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 44 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 | 
2 files changed, 25 insertions, 20 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 123a4b5..58108f2 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -52,20 +52,27 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))  {-# NOINLINE liftOpEltwise1 #-}  liftOpEltwise1 :: (Storable a, Storable b)                 => SNat n -               -> (VS.Vector a -> VS.Vector b) -               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ()) +               -> (Ptr a -> Ptr a') +               -> (Ptr b -> Ptr b') +               -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ())                 -> RS.Array n a -> RS.Array n b -liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) -  | Just (blockOff, blockSz) <- stridesDense sh offset strides = -      let vec' = f_vec (VS.slice blockOff blockSz vec) -      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) +  -- TODO: less code duplication between these two branches +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = unsafePerformIO $ do +      outv <- VSM.unsafeNew blockSz +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> +          VS.unsafeWith (VS.singleton 1) $ \pstrides -> +            VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> +              cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) +      RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv    | otherwise = unsafePerformIO $ do        outv <- VSM.unsafeNew (product sh)        VSM.unsafeWith outv $ \poutv ->          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->            VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->              VS.unsafeWith vec $ \pv -> -              cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv +              cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)        RS.fromVector sh <$> VS.unsafeFreeze outv  -- TODO: test all the cases of this thing with various input strides @@ -440,11 +447,10 @@ $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))            c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |] +               ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM floatTypesList $ \arithtype -> do @@ -506,12 +512,12 @@ $(fmap concat . forM typesList $ \arithtype -> do  -- This branch is ostensibly a runtime branch, but will (hopefully) be  -- constant-folded away by GHC.  intWidBranch1 :: forall i n. (FiniteBits i, Storable i) -              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) +              => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())                -> (SNat n -> RS.Array n i -> RS.Array n i)  intWidBranch1 f32 f64 sn -  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) -  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32 +  | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64    | otherwise = error "Unsupported Int width"  intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) @@ -666,9 +672,9 @@ instance NumElt Int where    numEltMul = intWidBranch2 @Int (*)                  (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))                  (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))    numEltSum1Inner = intWidBranchRed1 @Int                        (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))                        (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) @@ -692,9 +698,9 @@ instance NumElt CInt where    numEltMul = intWidBranch2 @CInt (*)                  (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))                  (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))    numEltSum1Inner = intWidBranchRed1 @CInt                        (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))                        (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 22c5b53..b53eb36 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -15,7 +15,6 @@ $(do          [("binary_" ++ tyn ++ "_vv",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])          ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) -        ,("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,("unary_" ++ tyn ++ "_strided",   [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) | 
