Specify the number of threads for parallel sort. (#8735)
* Specify the number of threads for parallel sort. - Pass context object into argsort. - Replace macros with inline functions.
This commit is contained in:
@@ -6,24 +6,25 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/algorithm.h" // ArgSort
|
||||
#include "../common/numeric.h" // RunLengthEncode
|
||||
#include "../common/stats.h" // Quantile,WeightedQuantile
|
||||
#include "../common/threading_utils.h" // ParallelFor
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
namespace detail {
|
||||
void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
|
||||
std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
|
||||
std::vector<size_t>* p_ridx) {
|
||||
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
||||
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
|
||||
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) {
|
||||
auto& nptr = *p_nptr;
|
||||
auto& nidx = *p_nidx;
|
||||
auto& ridx = *p_ridx;
|
||||
ridx = common::ArgSort<size_t>(position);
|
||||
ridx = common::ArgSort<size_t>(ctx, position.cbegin(), position.cend());
|
||||
std::vector<bst_node_t> sorted_pos(position);
|
||||
// permutation
|
||||
for (size_t i = 0; i < position.size(); ++i) {
|
||||
@@ -74,7 +75,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
||||
std::vector<bst_node_t> nidx;
|
||||
std::vector<size_t> nptr;
|
||||
std::vector<size_t> ridx;
|
||||
EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
|
||||
EncodeTreeLeafHost(ctx, *p_tree, position, &nptr, &nidx, &ridx);
|
||||
size_t n_leaf = nidx.size();
|
||||
if (nptr.empty()) {
|
||||
std::vector<float> quantiles;
|
||||
|
||||
Reference in New Issue
Block a user