Remove omp_get_max_threads in CPU predictor. (#7519)
This is part of the on going effort to remove the dependency on global omp variables.
This commit is contained in:
parent
5516281881
commit
68cdbc9c16
@ -24,7 +24,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||
bool seed_per_iteration;
|
||||
// number of threads to use if OpenMP is enabled
|
||||
// if equals 0, use system default
|
||||
int nthread;
|
||||
int nthread{0};
|
||||
// primary device, -1 means no gpu.
|
||||
int gpu_id;
|
||||
// fail when gpu_id is invalid
|
||||
|
||||
@ -105,11 +105,11 @@ class Predictor {
|
||||
/*
|
||||
* \brief Runtime parameters.
|
||||
*/
|
||||
GenericParameter const* generic_param_;
|
||||
GenericParameter const* ctx_;
|
||||
|
||||
public:
|
||||
explicit Predictor(GenericParameter const* generic_param) :
|
||||
generic_param_{generic_param} {}
|
||||
explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {}
|
||||
|
||||
virtual ~Predictor() = default;
|
||||
|
||||
/**
|
||||
|
||||
@ -150,10 +150,12 @@ class AdapterView {
|
||||
static size_t constexpr kUnroll = kUnrollLen;
|
||||
|
||||
public:
|
||||
explicit AdapterView(Adapter *adapter, float missing,
|
||||
common::Span<Entry> workplace)
|
||||
: adapter_{adapter}, missing_{missing}, workspace_{workplace},
|
||||
current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
|
||||
explicit AdapterView(Adapter *adapter, float missing, common::Span<Entry> workplace,
|
||||
int32_t n_threads)
|
||||
: adapter_{adapter},
|
||||
missing_{missing},
|
||||
workspace_{workplace},
|
||||
current_unroll_(n_threads > 0 ? n_threads : 1, 0) {}
|
||||
SparsePage::Inst operator[](size_t i) {
|
||||
bst_feature_t columns = adapter_->NumColumns();
|
||||
auto const &batch = adapter_->Value();
|
||||
@ -186,7 +188,7 @@ template <typename DataView, size_t block_of_rows_size>
|
||||
void PredictBatchByBlockOfRowsKernel(
|
||||
DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
|
||||
auto &thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param->num_output_group;
|
||||
|
||||
@ -197,7 +199,7 @@ void PredictBatchByBlockOfRowsKernel(
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
|
||||
|
||||
common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
|
||||
common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
|
||||
const size_t batch_offset = block_id * block_of_rows_size;
|
||||
const size_t block_size =
|
||||
std::min(nsize - batch_offset, block_of_rows_size);
|
||||
@ -252,7 +254,7 @@ class CPUPredictor : public Predictor {
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) const {
|
||||
const int threads = omp_get_max_threads();
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
constexpr double kDensityThresh = .5;
|
||||
size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
|
||||
static_cast<uint64_t>(1));
|
||||
@ -261,7 +263,7 @@ class CPUPredictor : public Predictor {
|
||||
bool blocked = density > kDensityThresh;
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1),
|
||||
InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1),
|
||||
model.learner_model_param->num_feature, &feat_vecs);
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
@ -269,15 +271,14 @@ class CPUPredictor : public Predictor {
|
||||
model.learner_model_param->num_output_group);
|
||||
size_t constexpr kUnroll = 8;
|
||||
if (blocked) {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
|
||||
kBlockOfRowsSize>(
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &feat_vecs);
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, kBlockOfRowsSize>(
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
|
||||
n_threads);
|
||||
|
||||
} else {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &feat_vecs);
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
|
||||
n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -304,7 +305,7 @@ class CPUPredictor : public Predictor {
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
auto threads = omp_get_max_threads();
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
@ -316,14 +317,14 @@ class CPUPredictor : public Predictor {
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
}
|
||||
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
|
||||
std::vector<Entry> workspace(m->NumColumns() * 8 * n_threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature,
|
||||
InitThreadTemp(n_threads * kBlockSize, model.learner_model_param->num_feature,
|
||||
&thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
|
||||
}
|
||||
|
||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
@ -370,10 +371,10 @@ class CPUPredictor : public Predictor {
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
InitThreadTemp(nthread, num_feature, &feat_vecs);
|
||||
InitThreadTemp(n_threads, num_feature, &feat_vecs);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
@ -386,7 +387,7 @@ class CPUPredictor : public Predictor {
|
||||
// parallel over local batch
|
||||
auto page = batch.GetView();
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
common::ParallelFor(nsize, [&](bst_omp_uint i) {
|
||||
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
auto ridx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = feat_vecs[tid];
|
||||
@ -411,10 +412,10 @@ class CPUPredictor : public Predictor {
|
||||
std::vector<bst_float> const *tree_weights,
|
||||
bool approximate, int condition,
|
||||
unsigned condition_feature) const override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(nthread, num_feature, &feat_vecs);
|
||||
InitThreadTemp(n_threads, num_feature, &feat_vecs);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
@ -432,7 +433,7 @@ class CPUPredictor : public Predictor {
|
||||
std::fill(contribs.begin(), contribs.end(), 0);
|
||||
// initialize tree node mean values
|
||||
std::vector<std::vector<float>> mean_values(ntree_limit);
|
||||
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
|
||||
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
|
||||
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
|
||||
});
|
||||
auto base_margin = info.base_margin_.View(GenericParameter::kCpuId);
|
||||
@ -441,7 +442,7 @@ class CPUPredictor : public Predictor {
|
||||
auto page = batch.GetView();
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
common::ParallelFor(nsize, [&](bst_omp_uint i) {
|
||||
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
|
||||
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
|
||||
if (feats.Size() == 0) {
|
||||
|
||||
@ -633,12 +633,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset, bool is_dense) const {
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(ctx_->gpu_id);
|
||||
batch.data.SetDevice(ctx_->gpu_id);
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.Size();
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
|
||||
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
|
||||
size_t shared_memory_bytes =
|
||||
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
|
||||
bool use_shared = shared_memory_bytes != 0;
|
||||
@ -694,10 +694,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (tree_end - tree_begin == 0) {
|
||||
return;
|
||||
}
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
out_preds->SetDevice(ctx_->gpu_id);
|
||||
auto const& info = dmat->Info();
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
|
||||
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
|
||||
|
||||
if (dmat->PageExists<SparsePage>()) {
|
||||
size_t batch_offset = 0;
|
||||
@ -709,10 +709,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
} else {
|
||||
size_t batch_offset = 0;
|
||||
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
||||
dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
|
||||
dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
|
||||
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
|
||||
this->PredictInternal(
|
||||
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
|
||||
page.Impl()->GetDeviceAccessor(ctx_->gpu_id, feature_types),
|
||||
d_model,
|
||||
out_preds,
|
||||
batch_offset);
|
||||
@ -726,15 +726,15 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
Predictor::Predictor{generic_param} {}
|
||||
|
||||
~GPUPredictor() override {
|
||||
if (generic_param_->gpu_id >= 0 && generic_param_->gpu_id < common::AllVisibleGPUs()) {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
}
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
|
||||
const gbm::GBTreeModel& model, uint32_t tree_begin,
|
||||
uint32_t tree_end = 0) const override {
|
||||
int device = generic_param_->gpu_id;
|
||||
int device = ctx_->gpu_id;
|
||||
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
|
||||
auto* out_preds = &predts->predictions;
|
||||
if (tree_end == 0) {
|
||||
@ -754,7 +754,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
|
||||
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
|
||||
<< "XGBoost is running on device: " << this->ctx_->gpu_id << ", "
|
||||
<< "but data is on: " << m->DeviceIdx();
|
||||
if (p_m) {
|
||||
p_m->Info().num_row_ = m->NumRows();
|
||||
@ -821,8 +821,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (tree_weights != nullptr) {
|
||||
LOG(FATAL) << "Dart booster feature " << not_implemented;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
out_contribs->SetDevice(generic_param_->gpu_id);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
out_contribs->SetDevice(ctx_->gpu_id);
|
||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
@ -840,12 +840,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
|
||||
device_paths;
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
|
||||
d_model.Init(model, 0, tree_end, ctx_->gpu_id);
|
||||
dh::device_vector<uint32_t> categories;
|
||||
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
|
||||
ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
|
||||
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(ctx_->gpu_id);
|
||||
batch.offset.SetDevice(ctx_->gpu_id);
|
||||
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature);
|
||||
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
|
||||
@ -854,7 +854,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::tend(phis));
|
||||
}
|
||||
// Add the base margin term to last column
|
||||
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
|
||||
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
|
||||
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
|
||||
float base_score = model.learner_model_param->base_score;
|
||||
dh::LaunchN(
|
||||
@ -879,8 +879,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (tree_weights != nullptr) {
|
||||
LOG(FATAL) << "Dart booster feature " << not_implemented;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
out_contribs->SetDevice(generic_param_->gpu_id);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
out_contribs->SetDevice(ctx_->gpu_id);
|
||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
@ -899,12 +899,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
|
||||
device_paths;
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
|
||||
d_model.Init(model, 0, tree_end, ctx_->gpu_id);
|
||||
dh::device_vector<uint32_t> categories;
|
||||
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
|
||||
ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
|
||||
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(ctx_->gpu_id);
|
||||
batch.offset.SetDevice(ctx_->gpu_id);
|
||||
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature);
|
||||
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
|
||||
@ -913,7 +913,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::tend(phis));
|
||||
}
|
||||
// Add the base margin term to last column
|
||||
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
|
||||
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
|
||||
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
|
||||
float base_score = model.learner_model_param->base_score;
|
||||
size_t n_features = model.learner_model_param->num_feature;
|
||||
@ -938,8 +938,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
|
||||
const gbm::GBTreeModel &model,
|
||||
unsigned tree_end) const override {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
|
||||
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
constexpr uint32_t kBlockThreads = 128;
|
||||
@ -953,15 +953,15 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
predictions->SetDevice(generic_param_->gpu_id);
|
||||
predictions->SetDevice(ctx_->gpu_id);
|
||||
predictions->Resize(num_rows * tree_end);
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);
|
||||
d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);
|
||||
|
||||
if (p_fmat->PageExists<SparsePage>()) {
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(ctx_->gpu_id);
|
||||
batch.offset.SetDevice(ctx_->gpu_id);
|
||||
bst_row_t batch_offset = 0;
|
||||
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature};
|
||||
@ -986,7 +986,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
} else {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
|
||||
bst_row_t batch_offset = 0;
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
|
||||
size_t num_rows = batch.Size();
|
||||
auto grid =
|
||||
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
|
||||
|
||||
@ -77,8 +77,8 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
|
||||
size_t n_classes = model.learner_model_param->num_output_group;
|
||||
size_t n = n_classes * info.num_row_;
|
||||
const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
|
||||
if (generic_param_->gpu_id >= 0) {
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
if (ctx_->gpu_id >= 0) {
|
||||
out_preds->SetDevice(ctx_->gpu_id);
|
||||
}
|
||||
if (base_margin->Size() != 0) {
|
||||
out_preds->Resize(n);
|
||||
|
||||
@ -217,10 +217,9 @@ void TestCategoricalPrediction(std::string name) {
|
||||
gbm::GBTreeModel model(¶m);
|
||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||
|
||||
GenericParameter runtime;
|
||||
runtime.gpu_id = 0;
|
||||
std::unique_ptr<Predictor> predictor{
|
||||
Predictor::Create(name.c_str(), &runtime)};
|
||||
GenericParameter ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
|
||||
|
||||
std::vector<float> row(kCols);
|
||||
row[split_ind] = split_cat;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user