diff options
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 63 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 7 | 
2 files changed, 70 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 91f994b..a57de1d 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -183,6 +183,34 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 . RS.fromVector @_ @lenFm1 (init shF)                 <$> VS.unsafeFreeze outv +-- | Find extremum (argmin or argmax) in full array +{-# NOINLINE vectorExtremumOp #-} +vectorExtremumOp :: forall a b n. Storable a +                 => (Ptr a -> Ptr b) +                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel +                 -> RS.Array n a -> [Int]  -- result length: n +vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = [] +  | any (<= 0) sh = error "Extremum (argmin/argmax): empty array" +  -- now the input array is nonempty +  | all (<= 0) strides = 0 <$ sh +  | otherwise = +      let -- replicated dimensions: dimensions with zero stride. The extremum +          -- kernel need not concern itself with those (and in fact has a +          -- precondition that there are no such dimensions in its input). +          replDims = map (== 0) strides +          -- filter out replicated dimensions +          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) +          ndimsF = length shF  -- > 0, because not all strides were <=0 +      in unsafePerformIO $ do +           outv <- VSM.unsafeNew (length shF) +           VSM.unsafeWith outv $ \poutv -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> +                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> +                   fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) +           map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()  flipOp f n out v s = f n out s v @@ -246,6 +274,16 @@ $(fmap concat . forM typesList $ \arithtype -> do                 ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]                     return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM typesList $ \arithtype -> +    fmap concat . forM ["min", "max"] $ \fname -> do +      let ttyp = conT (atType arithtype) +          name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype)) +      sequence [SigD name <$> +                     [t| forall n. RS.Array n $ttyp -> [Int] |] +               ,do body <- [| vectorExtremumOp id $c_op |] +                   return $ FunD name [Clause [] (NormalB body) []]]) +  -- This branch is ostensibly a runtime branch, but will (hopefully) be  -- constant-folded away by GHC.  intWidBranch1 :: forall i n. (FiniteBits i, Storable i) @@ -286,6 +324,17 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn    | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64    | otherwise = error "Unsupported Int width" +intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i) +                 => -- int32 +                    (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ extremum kernel +                    -- int64 +                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ extremum kernel +                 -> (RS.Array n i -> [Int]) +intWidBranchExtr fextr32 fextr64 +  | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32 +  | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 +  | otherwise = error "Unsupported Int width" +  class NumElt a where    numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a    numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a @@ -295,6 +344,8 @@ class NumElt a where    numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a    numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a    numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltArgMin :: RS.Array n a -> [Int] +  numEltArgMax :: RS.Array n a -> [Int]  instance NumElt Int32 where    numEltAdd = addVectorInt32 @@ -305,6 +356,8 @@ instance NumElt Int32 where    numEltSignum = signumVectorInt32    numEltSum1Inner = sum1VectorInt32    numEltProduct1Inner = product1VectorInt32 +  numEltArgMin = argminVectorInt32 +  numEltArgMax = argmaxVectorInt32  instance NumElt Int64 where    numEltAdd = addVectorInt64 @@ -315,6 +368,8 @@ instance NumElt Int64 where    numEltSignum = signumVectorInt64    numEltSum1Inner = sum1VectorInt64    numEltProduct1Inner = product1VectorInt64 +  numEltArgMin = argminVectorInt64 +  numEltArgMax = argmaxVectorInt64  instance NumElt Float where    numEltAdd = addVectorFloat @@ -325,6 +380,8 @@ instance NumElt Float where    numEltSignum = signumVectorFloat    numEltSum1Inner = sum1VectorFloat    numEltProduct1Inner = product1VectorFloat +  numEltArgMin = argminVectorFloat +  numEltArgMax = argmaxVectorFloat  instance NumElt Double where    numEltAdd = addVectorDouble @@ -335,6 +392,8 @@ instance NumElt Double where    numEltSignum = signumVectorDouble    numEltSum1Inner = sum1VectorDouble    numEltProduct1Inner = product1VectorDouble +  numEltArgMin = argminVectorDouble +  numEltArgMax = argmaxVectorDouble  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) @@ -355,6 +414,8 @@ instance NumElt Int where    numEltProduct1Inner = intWidBranchRed @Int                            (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))                            (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) +  numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 +  numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) @@ -375,6 +436,8 @@ instance NumElt CInt where    numEltProduct1Inner = intWidBranchRed @CInt                            (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))                            (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) +  numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 +  numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64  class FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 6fc7229..0bd72e8 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -53,3 +53,10 @@ $(fmap concat . forM typesList $ \arithtype -> do      let base = "reduce_" ++ atCName arithtype      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>        [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> +    fmap concat . forM ["min", "max"] $ \fname -> do +      let ttyp = conT (atType arithtype) +      let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype +      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +        [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) | 
