diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h
index 9301f10ce..dc76b7c3a 100644
--- a/include/xgboost/generic_parameters.h
+++ b/include/xgboost/generic_parameters.h
@@ -24,7 +24,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
   bool seed_per_iteration;
   // number of threads to use if OpenMP is enabled
   // if equals 0, use system default
-  int nthread;
+  int nthread{0};
   // primary device, -1 means no gpu.
   int gpu_id;
   // fail when gpu_id is invalid
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index 04f218c1e..506392261 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -105,11 +105,11 @@ class Predictor {
   /*
    * \brief Runtime parameters.
    */
-  GenericParameter const* generic_param_;
+  GenericParameter const* ctx_;
 
  public:
-  explicit Predictor(GenericParameter const* generic_param) :
-      generic_param_{generic_param} {}
+  explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {}
+
   virtual ~Predictor() = default;
 
   /**
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 92797235d..892c95631 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -150,10 +150,12 @@ class AdapterView {
   static size_t constexpr kUnroll = kUnrollLen;
 
  public:
-  explicit AdapterView(Adapter *adapter, float missing,
-                       common::Span<float> workplace)
-      : adapter_{adapter}, missing_{missing}, workspace_{workplace},
-        current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
+  explicit AdapterView(Adapter *adapter, float missing, common::Span<float> workplace,
+                       int32_t n_threads)
+      : adapter_{adapter},
+        missing_{missing},
+        workspace_{workplace},
+        current_unroll_(n_threads > 0 ? n_threads : 1, 0) {}
   SparsePage::Inst operator[](size_t i) {
     bst_feature_t columns = adapter_->NumColumns();
     auto const &batch = adapter_->Value();
@@ -186,7 +188,7 @@ template <typename DataView, size_t block_of_rows_size>
 void PredictBatchByBlockOfRowsKernel(
     DataView batch, std::vector<bst_float> *out_preds,
     gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
-    std::vector<RegTree::FVec> *p_thread_temp) {
+    std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
   auto &thread_temp = *p_thread_temp;
   int32_t const num_group = model.learner_model_param->num_output_group;
 
@@ -197,7 +199,7 @@ void PredictBatchByBlockOfRowsKernel(
   const int num_feature = model.learner_model_param->num_feature;
   omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
 
-  common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
+  common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
     const size_t batch_offset = block_id * block_of_rows_size;
     const size_t block_size =
         std::min(nsize - batch_offset, block_of_rows_size);
@@ -252,7 +254,7 @@ class CPUPredictor : public Predictor {
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin,
                       int32_t tree_end) const {
-    const int threads = omp_get_max_threads();
+    auto const n_threads = this->ctx_->Threads();
     constexpr double kDensityThresh = .5;
     size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
                             static_cast<uint64_t>(1));
@@ -261,7 +263,7 @@ class CPUPredictor : public Predictor {
     bool blocked = density > kDensityThresh;
 
     std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1),
+    InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1),
                    model.learner_model_param->num_feature, &feat_vecs);
     for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
@@ -269,15 +271,14 @@ class CPUPredictor : public Predictor {
                model.learner_model_param->num_output_group);
       size_t constexpr kUnroll = 8;
       if (blocked) {
-        PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
-                                        kBlockOfRowsSize>(
-            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
-            tree_end, &feat_vecs);
+        PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, kBlockOfRowsSize>(
+            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
+            n_threads);
       } else {
         PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
-            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
-            tree_end, &feat_vecs);
+            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
+            n_threads);
       }
     }
   }
 
@@ -304,7 +305,7 @@ class CPUPredictor : public Predictor {
                                 const gbm::GBTreeModel &model, float missing,
                                 PredictionCacheEntry *out_preds, uint32_t tree_begin,
                                 uint32_t tree_end) const {
-    auto threads = omp_get_max_threads();
+    auto const n_threads = this->ctx_->Threads();
     auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
     CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
         << "Number of columns in data must equal to trained model.";
@@ -316,14 +317,14 @@ class CPUPredictor : public Predictor {
       info.num_row_ = m->NumRows();
       this->InitOutPredictions(info, &(out_preds->predictions), model);
     }
-    std::vector<float> workspace(m->NumColumns() * 8 * threads);
+    std::vector<float> workspace(m->NumColumns() * 8 * n_threads);
     auto &predictions = out_preds->predictions.HostVector();
     std::vector<RegTree::FVec> thread_temp;
-    InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature,
+    InitThreadTemp(n_threads * kBlockSize, model.learner_model_param->num_feature,
                    &thread_temp);
     PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
-        AdapterView<Adapter>(m.get(), missing, common::Span<float>{workspace}),
-        &predictions, model, tree_begin, tree_end, &thread_temp);
+        AdapterView<Adapter>(m.get(), missing, common::Span<float>{workspace}, n_threads),
+        &predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
   }
 
   bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
@@ -370,10 +371,10 @@ class CPUPredictor : public Predictor {
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
                    const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
-    const int nthread = omp_get_max_threads();
+    auto const n_threads = this->ctx_->Threads();
     std::vector<RegTree::FVec> feat_vecs;
     const int num_feature = model.learner_model_param->num_feature;
-    InitThreadTemp(nthread, num_feature, &feat_vecs);
+    InitThreadTemp(n_threads, num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@@ -386,7 +387,7 @@ class CPUPredictor : public Predictor {
       // parallel over local batch
       auto page = batch.GetView();
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
-      common::ParallelFor(nsize, [&](bst_omp_uint i) {
+      common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
         const int tid = omp_get_thread_num();
         auto ridx = static_cast<size_t>(batch.base_rowid + i);
         RegTree::FVec &feats = feat_vecs[tid];
@@ -411,10 +412,10 @@ class CPUPredictor : public Predictor {
                            std::vector<bst_float> const *tree_weights,
                            bool approximate, int condition,
                            unsigned condition_feature) const override {
-    const int nthread = omp_get_max_threads();
+    auto const n_threads = this->ctx_->Threads();
     const int num_feature = model.learner_model_param->num_feature;
     std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(nthread, num_feature, &feat_vecs);
+    InitThreadTemp(n_threads, num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@@ -432,7 +433,7 @@ class CPUPredictor : public Predictor {
     std::fill(contribs.begin(), contribs.end(), 0);
     // initialize tree node mean values
     std::vector<std::vector<float>> mean_values(ntree_limit);
-    common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
+    common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
       FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
     });
     auto base_margin = info.base_margin_.View(GenericParameter::kCpuId);
@@ -441,7 +442,7 @@ class CPUPredictor : public Predictor {
       auto page = batch.GetView();
       // parallel over local batch
      const auto nsize = static_cast<bst_omp_uint>(batch.Size());
-      common::ParallelFor(nsize, [&](bst_omp_uint i) {
+      common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
         auto row_idx = static_cast<size_t>(batch.base_rowid + i);
         RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
         if (feats.Size() == 0) {
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 71724f952..5c61fafa0 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -633,12 +633,12 @@ class GPUPredictor : public xgboost::Predictor {
                        size_t num_features, HostDeviceVector<bst_float>* predictions,
                        size_t batch_offset, bool is_dense) const {
-    batch.offset.SetDevice(generic_param_->gpu_id);
-    batch.data.SetDevice(generic_param_->gpu_id);
+    batch.offset.SetDevice(ctx_->gpu_id);
+    batch.data.SetDevice(ctx_->gpu_id);
     const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
-    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
+    auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
     size_t shared_memory_bytes =
         SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
 
@@ -694,10 +694,10 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_end - tree_begin == 0) {
       return;
     }
-    out_preds->SetDevice(generic_param_->gpu_id);
+    out_preds->SetDevice(ctx_->gpu_id);
     auto const& info = dmat->Info();
     DeviceModel d_model;
-    d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
+    d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
 
     if (dmat->PageExists<SparsePage>()) {
       size_t batch_offset = 0;
@@ -709,10 +709,10 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       size_t batch_offset = 0;
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
-        dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
+        dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
         auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
         this->PredictInternal(
-            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
+            page.Impl()->GetDeviceAccessor(ctx_->gpu_id, feature_types),
             d_model,
             out_preds,
             batch_offset);
@@ -726,15 +726,15 @@ class GPUPredictor : public xgboost::Predictor {
       Predictor::Predictor{generic_param} {}
 
   ~GPUPredictor() override {
-    if (generic_param_->gpu_id >= 0 && generic_param_->gpu_id < common::AllVisibleGPUs()) {
-      dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
+    if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
+      dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     }
   }
 
   void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                     const gbm::GBTreeModel& model, uint32_t tree_begin,
                     uint32_t tree_end = 0) const override {
-    int device = generic_param_->gpu_id;
+    int device = ctx_->gpu_id;
     CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
     auto* out_preds = &predts->predictions;
     if (tree_end == 0) {
@@ -754,7 +754,7 @@ class GPUPredictor : public xgboost::Predictor {
     CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
         << "Number of columns in data must equal to trained model.";
     CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
-        << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
+        << "XGBoost is running on device: " << this->ctx_->gpu_id << ", "
         << "but data is on: " << m->DeviceIdx();
     if (p_m) {
       p_m->Info().num_row_ = m->NumRows();
@@ -821,8 +821,8 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_weights != nullptr) {
       LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    out_contribs->SetDevice(generic_param_->gpu_id);
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
+    out_contribs->SetDevice(ctx_->gpu_id);
     if (tree_end == 0 || tree_end > model.trees.size()) {
       tree_end = static_cast<uint32_t>(model.trees.size());
     }
@@ -840,12 +840,12 @@ class GPUPredictor : public xgboost::Predictor {
 
     dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
     DeviceModel d_model;
-    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    d_model.Init(model, 0, tree_end, ctx_->gpu_id);
     dh::device_vector<uint32_t> categories;
-    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
+    ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
     for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
-      batch.data.SetDevice(generic_param_->gpu_id);
-      batch.offset.SetDevice(generic_param_->gpu_id);
+      batch.data.SetDevice(ctx_->gpu_id);
+      batch.offset.SetDevice(ctx_->gpu_id);
       SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                        model.learner_model_param->num_feature);
       auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
@@ -854,7 +854,7 @@ class GPUPredictor : public xgboost::Predictor {
                              dh::tend(phis));
     }
     // Add the base margin term to last column
-    p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
+    p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
     const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
     float base_score = model.learner_model_param->base_score;
     dh::LaunchN(
@@ -879,8 +879,8 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_weights != nullptr) {
       LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    out_contribs->SetDevice(generic_param_->gpu_id);
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
+    out_contribs->SetDevice(ctx_->gpu_id);
     if (tree_end == 0 || tree_end > model.trees.size()) {
       tree_end = static_cast<uint32_t>(model.trees.size());
     }
@@ -899,12 +899,12 @@ class GPUPredictor : public xgboost::Predictor {
 
     dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
     DeviceModel d_model;
-    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    d_model.Init(model, 0, tree_end, ctx_->gpu_id);
     dh::device_vector<uint32_t> categories;
-    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
+    ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
     for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
-      batch.data.SetDevice(generic_param_->gpu_id);
-      batch.offset.SetDevice(generic_param_->gpu_id);
+      batch.data.SetDevice(ctx_->gpu_id);
+      batch.offset.SetDevice(ctx_->gpu_id);
       SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                        model.learner_model_param->num_feature);
       auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
@@ -913,7 +913,7 @@ class GPUPredictor : public xgboost::Predictor {
                                 dh::tend(phis));
     }
     // Add the base margin term to last column
-    p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
+    p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
     const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
     float base_score = model.learner_model_param->base_score;
     size_t n_features = model.learner_model_param->num_feature;
@@ -938,8 +938,8 @@ class GPUPredictor : public xgboost::Predictor {
   void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
                    const gbm::GBTreeModel &model,
                    unsigned tree_end) const override {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
+    auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
 
     const MetaInfo& info = p_fmat->Info();
     constexpr uint32_t kBlockThreads = 128;
@@ -953,15 +953,15 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_end == 0 || tree_end > model.trees.size()) {
       tree_end = static_cast<uint32_t>(model.trees.size());
     }
-    predictions->SetDevice(generic_param_->gpu_id);
+    predictions->SetDevice(ctx_->gpu_id);
     predictions->Resize(num_rows * tree_end);
     DeviceModel d_model;
-    d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);
+    d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);
 
     if (p_fmat->PageExists<SparsePage>()) {
       for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
-        batch.data.SetDevice(generic_param_->gpu_id);
-        batch.offset.SetDevice(generic_param_->gpu_id);
+        batch.data.SetDevice(ctx_->gpu_id);
+        batch.offset.SetDevice(ctx_->gpu_id);
         bst_row_t batch_offset = 0;
         SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                             model.learner_model_param->num_feature};
@@ -986,7 +986,7 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
         bst_row_t batch_offset = 0;
-        EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
+        EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
         size_t num_rows = batch.Size();
         auto grid =
             static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc
index 284d3b599..c4eb5d9db 100644
--- a/src/predictor/predictor.cc
+++ b/src/predictor/predictor.cc
@@ -77,8 +77,8 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_preds,
   size_t n_classes = model.learner_model_param->num_output_group;
   size_t n = n_classes * info.num_row_;
   const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
-  if (generic_param_->gpu_id >= 0) {
-    out_preds->SetDevice(generic_param_->gpu_id);
+  if (ctx_->gpu_id >= 0) {
+    out_preds->SetDevice(ctx_->gpu_id);
   }
   if (base_margin->Size() != 0) {
     out_preds->Resize(n);
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index b277da7d6..607741ca3 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -217,10 +217,9 @@ void TestCategoricalPrediction(std::string name) {
   gbm::GBTreeModel model(&param);
   GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
 
-  GenericParameter runtime;
-  runtime.gpu_id = 0;
-  std::unique_ptr<Predictor> predictor{
-      Predictor::Create(name.c_str(), &runtime)};
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
+  std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
 
   std::vector<float> row(kCols);
   row[split_ind] = split_cat;
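
Note (not part of the patch): the recurring change above is a single pattern. Call sites stop querying global OpenMP state (`omp_get_max_threads()`) and instead read the thread count once from the injected context (`ctx_->Threads()`), then pass it explicitly down to helpers such as `common::ParallelFor` and `AdapterView`. The sketch below illustrates that pattern in isolation; `Context` and `ParallelFor` here are simplified stand-ins written for this note, not the xgboost implementations.

    // Minimal sketch: thread count is configuration, resolved once at the
    // entry point, and passed explicitly to every parallel helper.
    #include <omp.h>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct Context {
      int nthread{0};  // 0 defers to the OpenMP runtime, mirroring nthread{0} above
      std::int32_t Threads() const {
        return nthread > 0 ? nthread : omp_get_max_threads();
      }
    };

    template <typename Fn>
    void ParallelFor(std::size_t n, std::int32_t n_threads, Fn&& fn) {
      // The caller decides the parallelism; this helper never reads globals.
    #pragma omp parallel for num_threads(n_threads)
      for (std::int64_t i = 0; i < static_cast<std::int64_t>(n); ++i) {
        fn(i);
      }
    }

    int main() {
      Context ctx;  // nthread left at its default of 0
      auto const n_threads = ctx.Threads();
      ParallelFor(8, n_threads, [&](std::int64_t i) {
        std::printf("block %lld on thread %d\n",
                    static_cast<long long>(i), omp_get_thread_num());
      });
      return 0;
    }

Resolving the count once keeps worker usage consistent across a whole prediction call even if another component changes the OpenMP global setting mid-flight, and the `n_threads > 0 ? n_threads : 1` guard in `AdapterView` now depends only on the value handed in rather than on OpenMP state at construction time.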