diff options
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 32ed355..d2ad61f 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -184,7 +184,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride <$> VS.unsafeFreeze outv -- TODO: test this function --- | Find extremum (argmin or argmax) in full array +-- | Find extremum (minindex ("argmin") or maxindex) in full array {-# NOINLINE vectorExtremumOp #-} vectorExtremumOp :: forall a b n. Storable a => (Ptr a -> Ptr b) @@ -192,7 +192,7 @@ vectorExtremumOp :: forall a b n. Storable a -> RS.Array n a -> [Int] -- result length: n vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = [] - | any (<= 0) sh = error "Extremum (argmin/argmax): empty array" + | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array" -- now the input array is nonempty | all (<= 0) strides = 0 <$ sh | otherwise = @@ -283,7 +283,7 @@ $(fmap concat . forM typesList $ \arithtype -> do $(fmap concat . forM typesList $ \arithtype -> fmap concat . forM ["min", "max"] $ \fname -> do let ttyp = conT (atType arithtype) - name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype)) + name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype)) c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype)) sequence [SigD name <$> [t| forall n. RS.Array n $ttyp -> [Int] |] @@ -350,8 +350,8 @@ class NumElt a where numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltArgMin :: RS.Array n a -> [Int] - numEltArgMax :: RS.Array n a -> [Int] + numEltMinIndex :: RS.Array n a -> [Int] + numEltMaxIndex :: RS.Array n a -> [Int] instance NumElt Int32 where numEltAdd = addVectorInt32 @@ -362,8 +362,8 @@ instance NumElt Int32 where numEltSignum = signumVectorInt32 numEltSum1Inner = sum1VectorInt32 numEltProduct1Inner = product1VectorInt32 - numEltArgMin = argminVectorInt32 - numEltArgMax = argmaxVectorInt32 + numEltMinIndex = minindexVectorInt32 + numEltMaxIndex = maxindexVectorInt32 instance NumElt Int64 where numEltAdd = addVectorInt64 @@ -374,8 +374,8 @@ instance NumElt Int64 where numEltSignum = signumVectorInt64 numEltSum1Inner = sum1VectorInt64 numEltProduct1Inner = product1VectorInt64 - numEltArgMin = argminVectorInt64 - numEltArgMax = argmaxVectorInt64 + numEltMinIndex = minindexVectorInt64 + numEltMaxIndex = maxindexVectorInt64 instance NumElt Float where numEltAdd = addVectorFloat @@ -386,8 +386,8 @@ instance NumElt Float where numEltSignum = signumVectorFloat numEltSum1Inner = sum1VectorFloat numEltProduct1Inner = product1VectorFloat - numEltArgMin = argminVectorFloat - numEltArgMax = argmaxVectorFloat + numEltMinIndex = minindexVectorFloat + numEltMaxIndex = maxindexVectorFloat instance NumElt Double where numEltAdd = addVectorDouble @@ -398,8 +398,8 @@ instance NumElt Double where numEltSignum = signumVectorDouble numEltSum1Inner = sum1VectorDouble numEltProduct1Inner = product1VectorDouble - numEltArgMin = argminVectorDouble - numEltArgMax = argmaxVectorDouble + numEltMinIndex = minindexVectorDouble + numEltMaxIndex = maxindexVectorDouble instance NumElt Int where numEltAdd = intWidBranch2 @Int (+) @@ -420,8 +420,8 @@ instance NumElt Int where numEltProduct1Inner = intWidBranchRed @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 - numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 + numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 instance NumElt CInt where numEltAdd = intWidBranch2 @CInt (+) @@ -442,8 +442,8 @@ instance NumElt CInt where numEltProduct1Inner = intWidBranchRed @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 - numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 + numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 class FloatElt a where floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a |