Remove omp_get_max_threads in CPU predictor. (#7519)

This is part of the on going effort to remove the dependency on global omp variables.
This commit is contained in:
Jiaming Yuan 2022-01-04 22:12:15 +08:00 committed by GitHub
parent 5516281881
commit 68cdbc9c16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 68 deletions

View File

@ -24,7 +24,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
bool seed_per_iteration; bool seed_per_iteration;
// number of threads to use if OpenMP is enabled // number of threads to use if OpenMP is enabled
// if equals 0, use system default // if equals 0, use system default
int nthread; int nthread{0};
// primary device, -1 means no gpu. // primary device, -1 means no gpu.
int gpu_id; int gpu_id;
// fail when gpu_id is invalid // fail when gpu_id is invalid

View File

@ -105,11 +105,11 @@ class Predictor {
/* /*
* \brief Runtime parameters. * \brief Runtime parameters.
*/ */
GenericParameter const* generic_param_; GenericParameter const* ctx_;
public: public:
explicit Predictor(GenericParameter const* generic_param) : explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {}
generic_param_{generic_param} {}
virtual ~Predictor() = default; virtual ~Predictor() = default;
/** /**

View File

@ -150,10 +150,12 @@ class AdapterView {
static size_t constexpr kUnroll = kUnrollLen; static size_t constexpr kUnroll = kUnrollLen;
public: public:
explicit AdapterView(Adapter *adapter, float missing, explicit AdapterView(Adapter *adapter, float missing, common::Span<Entry> workplace,
common::Span<Entry> workplace) int32_t n_threads)
: adapter_{adapter}, missing_{missing}, workspace_{workplace}, : adapter_{adapter},
current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {} missing_{missing},
workspace_{workplace},
current_unroll_(n_threads > 0 ? n_threads : 1, 0) {}
SparsePage::Inst operator[](size_t i) { SparsePage::Inst operator[](size_t i) {
bst_feature_t columns = adapter_->NumColumns(); bst_feature_t columns = adapter_->NumColumns();
auto const &batch = adapter_->Value(); auto const &batch = adapter_->Value();
@ -186,7 +188,7 @@ template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel( void PredictBatchByBlockOfRowsKernel(
DataView batch, std::vector<bst_float> *out_preds, DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end, gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) { std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
auto &thread_temp = *p_thread_temp; auto &thread_temp = *p_thread_temp;
int32_t const num_group = model.learner_model_param->num_output_group; int32_t const num_group = model.learner_model_param->num_output_group;
@ -197,7 +199,7 @@ void PredictBatchByBlockOfRowsKernel(
const int num_feature = model.learner_model_param->num_feature; const int num_feature = model.learner_model_param->num_feature;
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size); omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) { common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
const size_t batch_offset = block_id * block_of_rows_size; const size_t batch_offset = block_id * block_of_rows_size;
const size_t block_size = const size_t block_size =
std::min(nsize - batch_offset, block_of_rows_size); std::min(nsize - batch_offset, block_of_rows_size);
@ -252,7 +254,7 @@ class CPUPredictor : public Predictor {
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds, void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end) const { int32_t tree_end) const {
const int threads = omp_get_max_threads(); auto const n_threads = this->ctx_->Threads();
constexpr double kDensityThresh = .5; constexpr double kDensityThresh = .5;
size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_, size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
static_cast<uint64_t>(1)); static_cast<uint64_t>(1));
@ -261,7 +263,7 @@ class CPUPredictor : public Predictor {
bool blocked = density > kDensityThresh; bool blocked = density > kDensityThresh;
std::vector<RegTree::FVec> feat_vecs; std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1), InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1),
model.learner_model_param->num_feature, &feat_vecs); model.learner_model_param->num_feature, &feat_vecs);
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) { for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(), CHECK_EQ(out_preds->size(),
@ -269,15 +271,14 @@ class CPUPredictor : public Predictor {
model.learner_model_param->num_output_group); model.learner_model_param->num_output_group);
size_t constexpr kUnroll = 8; size_t constexpr kUnroll = 8;
if (blocked) { if (blocked) {
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, kBlockOfRowsSize>(
kBlockOfRowsSize>( SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, n_threads);
tree_end, &feat_vecs);
} else { } else {
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>( PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
tree_end, &feat_vecs); n_threads);
} }
} }
} }
@ -304,7 +305,7 @@ class CPUPredictor : public Predictor {
const gbm::GBTreeModel &model, float missing, const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const { uint32_t tree_begin, uint32_t tree_end) const {
auto threads = omp_get_max_threads(); auto const n_threads = this->ctx_->Threads();
auto m = dmlc::get<std::shared_ptr<Adapter>>(x); auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
@ -316,14 +317,14 @@ class CPUPredictor : public Predictor {
info.num_row_ = m->NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
} }
std::vector<Entry> workspace(m->NumColumns() * 8 * threads); std::vector<Entry> workspace(m->NumColumns() * 8 * n_threads);
auto &predictions = out_preds->predictions.HostVector(); auto &predictions = out_preds->predictions.HostVector();
std::vector<RegTree::FVec> thread_temp; std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature, InitThreadTemp(n_threads * kBlockSize, model.learner_model_param->num_feature,
&thread_temp); &thread_temp);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>( PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}), AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads),
&predictions, model, tree_begin, tree_end, &thread_temp); &predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
} }
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m, bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
@ -370,10 +371,10 @@ class CPUPredictor : public Predictor {
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds, void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) const override { const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
const int nthread = omp_get_max_threads(); auto const n_threads = this->ctx_->Threads();
std::vector<RegTree::FVec> feat_vecs; std::vector<RegTree::FVec> feat_vecs;
const int num_feature = model.learner_model_param->num_feature; const int num_feature = model.learner_model_param->num_feature;
InitThreadTemp(nthread, num_feature, &feat_vecs); InitThreadTemp(n_threads, num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
// number of valid trees // number of valid trees
if (ntree_limit == 0 || ntree_limit > model.trees.size()) { if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@ -386,7 +387,7 @@ class CPUPredictor : public Predictor {
// parallel over local batch // parallel over local batch
auto page = batch.GetView(); auto page = batch.GetView();
const auto nsize = static_cast<bst_omp_uint>(batch.Size()); const auto nsize = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nsize, [&](bst_omp_uint i) { common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
auto ridx = static_cast<size_t>(batch.base_rowid + i); auto ridx = static_cast<size_t>(batch.base_rowid + i);
RegTree::FVec &feats = feat_vecs[tid]; RegTree::FVec &feats = feat_vecs[tid];
@ -411,10 +412,10 @@ class CPUPredictor : public Predictor {
std::vector<bst_float> const *tree_weights, std::vector<bst_float> const *tree_weights,
bool approximate, int condition, bool approximate, int condition,
unsigned condition_feature) const override { unsigned condition_feature) const override {
const int nthread = omp_get_max_threads(); auto const n_threads = this->ctx_->Threads();
const int num_feature = model.learner_model_param->num_feature; const int num_feature = model.learner_model_param->num_feature;
std::vector<RegTree::FVec> feat_vecs; std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(nthread, num_feature, &feat_vecs); InitThreadTemp(n_threads, num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
// number of valid trees // number of valid trees
if (ntree_limit == 0 || ntree_limit > model.trees.size()) { if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@ -432,7 +433,7 @@ class CPUPredictor : public Predictor {
std::fill(contribs.begin(), contribs.end(), 0); std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values // initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit); std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) { common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
}); });
auto base_margin = info.base_margin_.View(GenericParameter::kCpuId); auto base_margin = info.base_margin_.View(GenericParameter::kCpuId);
@ -441,7 +442,7 @@ class CPUPredictor : public Predictor {
auto page = batch.GetView(); auto page = batch.GetView();
// parallel over local batch // parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size()); const auto nsize = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nsize, [&](bst_omp_uint i) { common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
auto row_idx = static_cast<size_t>(batch.base_rowid + i); auto row_idx = static_cast<size_t>(batch.base_rowid + i);
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()]; RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
if (feats.Size() == 0) { if (feats.Size() == 0) {

View File

@ -633,12 +633,12 @@ class GPUPredictor : public xgboost::Predictor {
size_t num_features, size_t num_features,
HostDeviceVector<bst_float>* predictions, HostDeviceVector<bst_float>* predictions,
size_t batch_offset, bool is_dense) const { size_t batch_offset, bool is_dense) const {
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(ctx_->gpu_id);
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(ctx_->gpu_id);
const uint32_t BLOCK_THREADS = 128; const uint32_t BLOCK_THREADS = 128;
size_t num_rows = batch.Size(); size_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS)); auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id); auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
size_t shared_memory_bytes = size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes); SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0; bool use_shared = shared_memory_bytes != 0;
@ -694,10 +694,10 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_end - tree_begin == 0) { if (tree_end - tree_begin == 0) {
return; return;
} }
out_preds->SetDevice(generic_param_->gpu_id); out_preds->SetDevice(ctx_->gpu_id);
auto const& info = dmat->Info(); auto const& info = dmat->Info();
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id); d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
if (dmat->PageExists<SparsePage>()) { if (dmat->PageExists<SparsePage>()) {
size_t batch_offset = 0; size_t batch_offset = 0;
@ -709,10 +709,10 @@ class GPUPredictor : public xgboost::Predictor {
} else { } else {
size_t batch_offset = 0; size_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>()) { for (auto const& page : dmat->GetBatches<EllpackPage>()) {
dmat->Info().feature_types.SetDevice(generic_param_->gpu_id); dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan(); auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
this->PredictInternal( this->PredictInternal(
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types), page.Impl()->GetDeviceAccessor(ctx_->gpu_id, feature_types),
d_model, d_model,
out_preds, out_preds,
batch_offset); batch_offset);
@ -726,15 +726,15 @@ class GPUPredictor : public xgboost::Predictor {
Predictor::Predictor{generic_param} {} Predictor::Predictor{generic_param} {}
~GPUPredictor() override { ~GPUPredictor() override {
if (generic_param_->gpu_id >= 0 && generic_param_->gpu_id < common::AllVisibleGPUs()) { if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
} }
} }
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, uint32_t tree_begin, const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t tree_end = 0) const override { uint32_t tree_end = 0) const override {
int device = generic_param_->gpu_id; int device = ctx_->gpu_id;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
auto* out_preds = &predts->predictions; auto* out_preds = &predts->predictions;
if (tree_end == 0) { if (tree_end == 0) {
@ -754,7 +754,7 @@ class GPUPredictor : public xgboost::Predictor {
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx()) CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "XGBoost is running on device: " << this->ctx_->gpu_id << ", "
<< "but data is on: " << m->DeviceIdx(); << "but data is on: " << m->DeviceIdx();
if (p_m) { if (p_m) {
p_m->Info().num_row_ = m->NumRows(); p_m->Info().num_row_ = m->NumRows();
@ -821,8 +821,8 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) { if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented; LOG(FATAL) << "Dart booster feature " << not_implemented;
} }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id); out_contribs->SetDevice(ctx_->gpu_id);
if (tree_end == 0 || tree_end > model.trees.size()) { if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size()); tree_end = static_cast<uint32_t>(model.trees.size());
} }
@ -840,12 +840,12 @@ class GPUPredictor : public xgboost::Predictor {
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths; device_paths;
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id); d_model.Init(model, 0, tree_end, ctx_->gpu_id);
dh::device_vector<uint32_t> categories; dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id); ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) { for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(ctx_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature); model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
@ -854,7 +854,7 @@ class GPUPredictor : public xgboost::Predictor {
dh::tend(phis)); dh::tend(phis));
} }
// Add the base margin term to last column // Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score; float base_score = model.learner_model_param->base_score;
dh::LaunchN( dh::LaunchN(
@ -879,8 +879,8 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) { if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented; LOG(FATAL) << "Dart booster feature " << not_implemented;
} }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id); out_contribs->SetDevice(ctx_->gpu_id);
if (tree_end == 0 || tree_end > model.trees.size()) { if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size()); tree_end = static_cast<uint32_t>(model.trees.size());
} }
@ -899,12 +899,12 @@ class GPUPredictor : public xgboost::Predictor {
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths; device_paths;
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id); d_model.Init(model, 0, tree_end, ctx_->gpu_id);
dh::device_vector<uint32_t> categories; dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id); ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) { for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(ctx_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature); model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
@ -913,7 +913,7 @@ class GPUPredictor : public xgboost::Predictor {
dh::tend(phis)); dh::tend(phis));
} }
// Add the base margin term to last column // Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score; float base_score = model.learner_model_param->base_score;
size_t n_features = model.learner_model_param->num_feature; size_t n_features = model.learner_model_param->num_feature;
@ -938,8 +938,8 @@ class GPUPredictor : public xgboost::Predictor {
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions, void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
const gbm::GBTreeModel &model, const gbm::GBTreeModel &model,
unsigned tree_end) const override { unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id); auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128; constexpr uint32_t kBlockThreads = 128;
@ -953,15 +953,15 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_end == 0 || tree_end > model.trees.size()) { if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size()); tree_end = static_cast<uint32_t>(model.trees.size());
} }
predictions->SetDevice(generic_param_->gpu_id); predictions->SetDevice(ctx_->gpu_id);
predictions->Resize(num_rows * tree_end); predictions->Resize(num_rows * tree_end);
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id); d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);
if (p_fmat->PageExists<SparsePage>()) { if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) { for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(ctx_->gpu_id);
bst_row_t batch_offset = 0; bst_row_t batch_offset = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(), SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature}; model.learner_model_param->num_feature};
@ -986,7 +986,7 @@ class GPUPredictor : public xgboost::Predictor {
} else { } else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) { for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
bst_row_t batch_offset = 0; bst_row_t batch_offset = 0;
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)}; EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
size_t num_rows = batch.Size(); size_t num_rows = batch.Size();
auto grid = auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads)); static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));

View File

@ -77,8 +77,8 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
size_t n_classes = model.learner_model_param->num_output_group; size_t n_classes = model.learner_model_param->num_output_group;
size_t n = n_classes * info.num_row_; size_t n = n_classes * info.num_row_;
const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data(); const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
if (generic_param_->gpu_id >= 0) { if (ctx_->gpu_id >= 0) {
out_preds->SetDevice(generic_param_->gpu_id); out_preds->SetDevice(ctx_->gpu_id);
} }
if (base_margin->Size() != 0) { if (base_margin->Size() != 0) {
out_preds->Resize(n); out_preds->Resize(n);

View File

@ -217,10 +217,9 @@ void TestCategoricalPrediction(std::string name) {
gbm::GBTreeModel model(&param); gbm::GBTreeModel model(&param);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight); GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
GenericParameter runtime; GenericParameter ctx;
runtime.gpu_id = 0; ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
std::unique_ptr<Predictor> predictor{ std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
Predictor::Create(name.c_str(), &runtime)};
std::vector<float> row(kCols); std::vector<float> row(kCols);
row[split_ind] = split_cat; row[split_ind] = split_cat;