diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-17 13:08:13 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-17 13:08:13 +0200 | 
| commit | 9b0651bf19e889dfb28ba81b6ada25b27b0e6071 (patch) | |
| tree | 53dafa55d48d9b73f148426ba5a308cc8cfa6410 /src/Data/Array/Mixed | |
| parent | 3d48baae00c066f43fa2205b22f0357f069888f2 (diff) | |
sumAllPrim
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 150 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 4 | 
3 files changed, 126 insertions, 38 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 579c0da..d547084 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -33,6 +33,9 @@ import Data.Array.Mixed.Internal.Arith.Foreign  import Data.Array.Mixed.Internal.Arith.Lists +-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition + +  -- TODO: test all the cases of this thing with various input strides  liftVEltwise1 :: (Storable a, Storable b)                => SNat n @@ -186,7 +189,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride            -- precondition that there are no such dimensions in its input).            replDims = map (== 0) strides            -- filter out replicated dimensions -          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) +          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]            -- replace replicated dimensions with ones            shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims            ndimsF = length shF  -- > 0, otherwise `last strides == 0` @@ -213,6 +216,48 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 . RS.fromVector @_ @lenFm1 (init shF)  -- the partially-reversed result array                 <$> VS.unsafeFreeze outvR +-- TODO: test handling of negative strides +-- | Reduce full array +{-# NOINLINE vectorRedFullOp #-} +vectorRedFullOp :: forall a b n. (Num a, Storable a) +                => SNat n +                -> (a -> Int -> a) +                -> (b -> a) +                -> (Ptr a -> Ptr b) +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel +                -> RS.Array n a -> a +vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = vec VS.! offset  -- 0D array has one element +  | any (<= 0) sh = 0 +  -- now the input array is nonempty +  | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset +  -- now there is at least one non-replicated dimension +  | otherwise = +      let -- replicated dimensions: dimensions with zero stride. The reduction +          -- kernel need not concern itself with those (and in fact has a +          -- precondition that there are no such dimensions in its input). +          replDims = map (== 0) strides +          -- filter out replicated dimensions +          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] +          ndimsF = length shF  -- > 0, otherwise `all (== 0) strides` +          -- we should scale up the output this many times to account for the replicated dimensions +          multiplier = product [n | (n, True) <- zip sh replDims] + +          -- reversed dimensions: dimensions with negative stride. Reversal is +          -- irrelevant for a reduction, and indeed the kernel has a +          -- precondition that there are no such dimensions. +          revDims = map (< 0) stridesF +          stridesR = map abs stridesF +          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) +          -- The *R values give an array with strides all > 0, hence the +          -- left-most element is at offsetR. +      in unsafePerformIO $ do +           VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> +               VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> +                 (`scaleval` fromIntegral multiplier) . valbackconv +                   <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) +  -- TODO: test this function  -- | Find extremum (minindex ("argmin") or maxindex) in full array  {-# NOINLINE vectorExtremumOp #-} @@ -232,7 +277,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))            -- precondition that there are no such dimensions in its input).            replDims = map (== 0) strides            -- filter out replicated dimensions -          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) +          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]            ndimsF = length shF  -- > 0, because not all strides were <=0            -- un-reverse reversed dimensions @@ -380,16 +425,29 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do                 ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]                     return $ FunD name [Clause [] (NormalB body) []]]) +mulWithInt :: Num a => a -> Int -> a +mulWithInt a i = a * fromIntegral i +  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) +      let scaleVar = case arithop of +                       RO_SUM -> varE 'mulWithInt +                       RO_PRODUCT -> varE '(^) +      let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype)) +          namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype)) +          c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) +          c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))            c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) -      sequence [SigD name <$> +      sequence [SigD name1 <$>                       [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] -                   return $ FunD name [Clause [] (NormalB body) []]]) +               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |] +                   return $ FunD name1 [Clause [] (NormalB body) []] +               ,SigD namefull <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] +               ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |] +                   return $ FunD namefull [Clause [] (NormalB body) []] +               ])  $(fmap concat . forM typesList $ \arithtype ->      fmap concat . forM ["min", "max"] $ \fname -> do @@ -406,7 +464,7 @@ $(fmap concat . forM typesList $ \arithtype -> do          name = mkName ("dotprodVector" ++ nameBase (atType arithtype))          c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))          c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided")) -        c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1))) +        c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))      sequence [SigD name <$>                     [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]               ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |] @@ -439,19 +497,31 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn    | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)    | otherwise = error "Unsupported Int width" -intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) -                => -- int32 -                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel -                   -- int64 -                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed fsc32 fred32 fsc64 fred64 sn +intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) +                 => -- int32 +                    (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                    -- int64 +                 -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                 -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn    | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32    | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64    | otherwise = error "Unsupported Int width" +intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i) +                    => (i -> Int -> i)  -- ^ scale op +                       -- int32 +                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32)  -- ^ reduction kernel +                       -- int64 +                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ reduction kernel +                    -> (SNat n -> RS.Array n i -> i) +intWidBranchRedFull fsc fred32 fred64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32 +  | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 +  | otherwise = error "Unsupported Int width" +  intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)                   => -- int32                      (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ extremum kernel @@ -487,6 +557,8 @@ class NumElt a where    numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a    numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a    numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltSumFull :: SNat n -> RS.Array n a -> a +  numEltProductFull :: SNat n -> RS.Array n a -> a    numEltMinIndex :: RS.Array n a -> [Int]    numEltMaxIndex :: RS.Array n a -> [Int]    numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a @@ -500,6 +572,8 @@ instance NumElt Int32 where    numEltSignum = signumVectorInt32    numEltSum1Inner = sum1VectorInt32    numEltProduct1Inner = product1VectorInt32 +  numEltSumFull = sumFullVectorInt32 +  numEltProductFull = productFullVectorInt32    numEltMinIndex = minindexVectorInt32    numEltMaxIndex = maxindexVectorInt32    numEltDotprod = dotprodVectorInt32 @@ -513,6 +587,8 @@ instance NumElt Int64 where    numEltSignum = signumVectorInt64    numEltSum1Inner = sum1VectorInt64    numEltProduct1Inner = product1VectorInt64 +  numEltSumFull = sumFullVectorInt64 +  numEltProductFull = productFullVectorInt64    numEltMinIndex = minindexVectorInt64    numEltMaxIndex = maxindexVectorInt64    numEltDotprod = dotprodVectorInt64 @@ -526,6 +602,8 @@ instance NumElt Float where    numEltSignum = signumVectorFloat    numEltSum1Inner = sum1VectorFloat    numEltProduct1Inner = product1VectorFloat +  numEltSumFull = sumFullVectorFloat +  numEltProductFull = productFullVectorFloat    numEltMinIndex = minindexVectorFloat    numEltMaxIndex = maxindexVectorFloat    numEltDotprod = dotprodVectorFloat @@ -539,6 +617,8 @@ instance NumElt Double where    numEltSignum = signumVectorDouble    numEltSum1Inner = sum1VectorDouble    numEltProduct1Inner = product1VectorDouble +  numEltSumFull = sumFullVectorDouble +  numEltProductFull = productFullVectorDouble    numEltMinIndex = minindexVectorDouble    numEltMaxIndex = maxindexVectorDouble    numEltDotprod = dotprodVectorDouble @@ -556,16 +636,18 @@ instance NumElt Int where    numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))    numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))    numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed @Int -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) -  numEltProduct1Inner = intWidBranchRed @Int -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) +  numEltSum1Inner = intWidBranchRed1 @Int +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) +  numEltProduct1Inner = intWidBranchRed1 @Int +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) +  numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) +  numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided -                                           (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided +  numEltDotprod = intWidBranchDotprod @Int (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided +                                           (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) @@ -580,16 +662,18 @@ instance NumElt CInt where    numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))    numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))    numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed @CInt -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) -  numEltProduct1Inner = intWidBranchRed @CInt -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) +  numEltSum1Inner = intWidBranchRed1 @CInt +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) +  numEltProduct1Inner = intWidBranchRed1 @CInt +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) +  numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) +  numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided -                                            (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided +  numEltDotprod = intWidBranchDotprod @CInt (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided +                                            (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided  class FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index a406dab..ca96093 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -49,9 +49,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    let base = "reduce_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +    let base1 = "reduce1_" ++ atCName arithtype +        basefull = "reducefull_" ++ atCName arithtype +    sequence +      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> +         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |] +      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$> +         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])  $(fmap concat . forM typesList $ \arithtype ->      fmap concat . forM ["min", "max"] $ \fname -> do diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index 08295cd..fa753bb 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -240,8 +240,8 @@ transpose2 ssh1 ssh2 (XArray arr)    , let n1 = ssxLength ssh1    = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) -sumFull :: (Storable a, NumElt a) => XArray sh a -> a -sumFull (XArray arr) = +sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a +sumFull _ (XArray arr) =    S.unScalar $      numEltSum1Inner (SNat @0) $        S.fromVector [product (S.shapeL arr)] $ | 
