Make prediction functions thread safe. (#6648)
commit c3c8e66fc9
parent 0f2ed21a9d
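In short: every `Predictor::Predict*` entry point becomes `const`, and the per-call scratch state that used to live inside the predictor objects (together with the mutexes that guarded it) is replaced by function-local buffers, so one booster can serve predictions from many threads at once. A minimal sketch of that pattern, using hypothetical `FeatureVector`/`ThreadSafePredictor` types rather than the real XGBoost classes:

#include <vector>

struct FeatureVector { std::vector<float> values; };

class ThreadSafePredictor {
 public:
  // Each call owns its scratch buffers: no shared mutable state, no lock needed.
  std::vector<float> Predict(const std::vector<float>& row) const {
    std::vector<FeatureVector> feat_vecs(1);  // local scratch, like feat_vecs in the diff
    feat_vecs[0].values = row;                // fill per-call working space
    return feat_vecs[0].values;               // stand-in for the real tree traversal
  }
};

The diff below applies this idea to the `Predictor` interface, the CPU and GPU implementations, and the multi-threaded learner test.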
@@ -132,7 +132,7 @@ class Predictor {
    */
   virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
                             const gbm::GBTreeModel& model, int tree_begin,
-                            uint32_t const ntree_limit = 0) = 0;
+                            uint32_t const ntree_limit = 0) const = 0;
 
   /**
    * \brief Inplace prediction.
@@ -161,7 +161,7 @@ class Predictor {
   virtual void PredictInstance(const SparsePage::Inst& inst,
                                std::vector<bst_float>* out_preds,
                                const gbm::GBTreeModel& model,
-                               unsigned ntree_limit = 0) = 0;
+                               unsigned ntree_limit = 0) const = 0;
 
   /**
    * \brief predict the leaf index of each tree, the output will be nsample *
@@ -175,7 +175,7 @@ class Predictor {
 
   virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                            const gbm::GBTreeModel& model,
                            unsigned ntree_limit = 0) = 0;
+                           unsigned ntree_limit = 0) const = 0;
 
   /**
    * \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
@@ -203,14 +203,14 @@ class Predictor {
                                    std::vector<bst_float>* tree_weights = nullptr,
                                    bool approximate = false,
                                    int condition = 0,
-                                   unsigned condition_feature = 0) = 0;
+                                   unsigned condition_feature = 0) const = 0;
 
   virtual void PredictInteractionContributions(DMatrix* dmat,
                                                HostDeviceVector<bst_float>* out_contribs,
                                                const gbm::GBTreeModel& model,
                                                unsigned ntree_limit = 0,
                                                std::vector<bst_float>* tree_weights = nullptr,
-                                               bool approximate = false) = 0;
+                                               bool approximate = false) const = 0;
 
 
   /**
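CPU predictor: the per-thread `RegTree::FVec` scratch buffers move from the mutable member `thread_temp_` (guarded by `lock_`) into a function-local `feat_vecs` that is rebuilt on every call, so each prediction method can become `const` and the mutex disappears. A hedged sketch of that per-call, per-OpenMP-thread allocation, using hypothetical names (`FVecSketch`, `InitThreadScratch`, `PredictRowsSketch`) rather than the real XGBoost helpers:

#include <omp.h>
#include <cstddef>
#include <vector>

struct FVecSketch {
  std::vector<float> data;
  void Init(std::size_t n) { data.assign(n, 0.0f); }
};

// Hypothetical stand-in for InitThreadTemp: one scratch buffer per OpenMP thread.
inline void InitThreadScratch(int nthread, std::size_t num_feature,
                              std::vector<FVecSketch>* out) {
  out->resize(nthread);
  for (auto& v : *out) {
    v.Init(num_feature);
  }
}

inline void PredictRowsSketch(std::size_t num_rows, std::size_t num_feature) {
  std::vector<FVecSketch> feat_vecs;  // lives only for this call
  InitThreadScratch(omp_get_max_threads(), num_feature, &feat_vecs);
#pragma omp parallel for schedule(static)
  for (long i = 0; i < static_cast<long>(num_rows); ++i) {
    FVecSketch& feats = feat_vecs[omp_get_thread_num()];  // private slot per thread
    (void)feats;  // fill features and walk the trees here
  }
}

The allocation cost moves from object construction to every prediction call, which is the price paid for dropping the shared buffer and its lock.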
@@ -183,11 +183,11 @@ class CPUPredictor : public Predictor {
 
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin,
-                      int32_t tree_end) {
-    std::lock_guard<std::mutex> guard(lock_);
+                      int32_t tree_end) const {
     const int threads = omp_get_max_threads();
-    InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
-                   &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(threads * kBlockOfRowsSize,
+                   model.learner_model_param->num_feature, &feat_vecs);
     for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
                p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
@@ -195,7 +195,7 @@ class CPUPredictor : public Predictor {
       PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
                                       kBlockOfRowsSize>(SparsePageView<kUnroll>{&batch},
                                                         out_preds, model, tree_begin,
-                                                        tree_end, &thread_temp_);
+                                                        tree_end, &feat_vecs);
     }
   }
 
@@ -238,7 +238,7 @@ class CPUPredictor : public Predictor {
   // multi-output and forest. Same problem exists for tree_begin
   void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                     const gbm::GBTreeModel& model, int tree_begin,
-                    uint32_t const ntree_limit = 0) override {
+                    uint32_t const ntree_limit = 0) const override {
     // tree_begin is not used, right now we just enforce it to be 0.
     CHECK_EQ(tree_begin, 0);
     auto* out_preds = &predts->predictions;
@@ -326,11 +326,10 @@ class CPUPredictor : public Predictor {
 
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
-                       const gbm::GBTreeModel& model, unsigned ntree_limit) override {
-    if (thread_temp_.size() == 0) {
-      thread_temp_.resize(1, RegTree::FVec());
-      thread_temp_[0].Init(model.learner_model_param->num_feature);
-    }
+                       const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
+    std::vector<RegTree::FVec> feat_vecs;
+    feat_vecs.resize(1, RegTree::FVec());
+    feat_vecs[0].Init(model.learner_model_param->num_feature);
     ntree_limit *= model.learner_model_param->num_output_group;
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
       ntree_limit = static_cast<unsigned>(model.trees.size());
@@ -340,15 +339,16 @@ class CPUPredictor : public Predictor {
     // loop over output groups
     for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
       (*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
-                                    &thread_temp_[0], 0, ntree_limit) +
+                                    &feat_vecs[0], 0, ntree_limit) +
                           model.learner_model_param->base_score;
     }
   }
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
-                   const gbm::GBTreeModel& model, unsigned ntree_limit) override {
+                   const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param->num_output_group;
@@ -366,7 +366,7 @@ class CPUPredictor : public Predictor {
     for (bst_omp_uint i = 0; i < nsize; ++i) {
       const int tid = omp_get_thread_num();
       auto ridx = static_cast<size_t>(batch.base_rowid + i);
-      RegTree::FVec &feats = thread_temp_[tid];
+      RegTree::FVec &feats = feat_vecs[tid];
       feats.Fill(page[i]);
       for (unsigned j = 0; j < ntree_limit; ++j) {
         int tid = model.trees[j]->GetLeafIndex(feats);
@@ -381,9 +381,10 @@ class CPUPredictor : public Predictor {
                            const gbm::GBTreeModel& model, uint32_t ntree_limit,
                            std::vector<bst_float>* tree_weights,
                            bool approximate, int condition,
-                           unsigned condition_feature) override {
+                           unsigned condition_feature) const override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param->num_output_group;
@@ -414,7 +415,7 @@ class CPUPredictor : public Predictor {
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize; ++i) {
       auto row_idx = static_cast<size_t>(batch.base_rowid + i);
-      RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
+      RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
       std::vector<bst_float> this_tree_contribs(ncolumns);
       // loop over all classes
       for (int gid = 0; gid < ngroup; ++gid) {
@@ -452,7 +453,7 @@ class CPUPredictor : public Predictor {
   void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model, unsigned ntree_limit,
                                        std::vector<bst_float>* tree_weights,
-                                       bool approximate) override {
+                                       bool approximate) const override {
     const MetaInfo& info = p_fmat->Info();
     const int ngroup = model.learner_model_param->num_output_group;
     size_t const ncolumns = model.learner_model_param->num_feature;
@@ -501,8 +502,6 @@ class CPUPredictor : public Predictor {
   }
 
  private:
-  std::mutex lock_;
-  std::vector<RegTree::FVec> thread_temp_;
   static size_t constexpr kBlockOfRowsSize = 64;
 };
 
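GPU predictor: the cached members `model_`, `lock_`, and `max_shared_memory_bytes_` are removed; each call now builds a local `DeviceModel` (`d_model`) and passes it explicitly into `PredictInternal`, and `ConfigureDevice` returns the shared-memory limit instead of storing it. A rough sketch of those two moves under hypothetical names (`DeviceModelSketch`, `GpuPredictorSketch`), with the CUDA kernel details elided:

#include <cstddef>

struct DeviceModelSketch { int tree_begin = 0; int tree_end = 0; };

class GpuPredictorSketch {
 public:
  void PredictRange(int tree_begin, int tree_end) const {
    DeviceModelSketch d_model{tree_begin, tree_end};  // local value, not a cached member
    std::size_t max_shared = QueryMaxSharedMemory(0); // returned, not written to a field
    LaunchSketch(d_model, max_shared);                // pass state down explicitly
  }

 private:
  static std::size_t QueryMaxSharedMemory(int device) { return device >= 0 ? 48 * 1024 : 0; }
  static void LaunchSketch(const DeviceModelSketch&, std::size_t) { /* kernel launch here */ }
};

Because nothing is written back into the predictor object, concurrent calls no longer race on the device model or the shared-memory cache, and the guarding mutex can go.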
@@ -501,17 +501,18 @@ size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
 class GPUPredictor : public xgboost::Predictor {
  private:
   void PredictInternal(const SparsePage& batch,
+                       DeviceModel const& model,
                        size_t num_features,
                        HostDeviceVector<bst_float>* predictions,
-                       size_t batch_offset) {
+                       size_t batch_offset) const {
     batch.offset.SetDevice(generic_param_->gpu_id);
     batch.data.SetDevice(generic_param_->gpu_id);
     const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
     size_t shared_memory_bytes =
-        SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
+        SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
 
     size_t entry_start = 0;
@@ -519,18 +520,19 @@ class GPUPredictor : public xgboost::Predictor {
                                num_features);
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>, data,
-        model_.nodes.ConstDeviceSpan(),
+        model.nodes.ConstDeviceSpan(),
         predictions->DeviceSpan().subspan(batch_offset),
-        model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
-        model_.split_types.ConstDeviceSpan(),
-        model_.categories_tree_segments.ConstDeviceSpan(),
-        model_.categories_node_segments.ConstDeviceSpan(),
-        model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
-        num_features, num_rows, entry_start, use_shared, model_.num_group);
+        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
+        model.split_types.ConstDeviceSpan(),
+        model.categories_tree_segments.ConstDeviceSpan(),
+        model.categories_node_segments.ConstDeviceSpan(),
+        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
+        num_features, num_rows, entry_start, use_shared, model.num_group);
   }
   void PredictInternal(EllpackDeviceAccessor const& batch,
+                       DeviceModel const& model,
                        HostDeviceVector<bst_float>* out_preds,
-                       size_t batch_offset) {
+                       size_t batch_offset) const {
     const uint32_t BLOCK_THREADS = 256;
     size_t num_rows = batch.n_rows;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
@@ -539,31 +541,31 @@ class GPUPredictor : public xgboost::Predictor {
     size_t entry_start = 0;
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
         PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
-        model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
-        model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
-        model_.split_types.ConstDeviceSpan(),
-        model_.categories_tree_segments.ConstDeviceSpan(),
-        model_.categories_node_segments.ConstDeviceSpan(),
-        model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
+        model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
+        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
+        model.split_types.ConstDeviceSpan(),
+        model.categories_tree_segments.ConstDeviceSpan(),
+        model.categories_node_segments.ConstDeviceSpan(),
+        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
         batch.NumFeatures(), num_rows, entry_start, use_shared,
-        model_.num_group);
+        model.num_group);
   }
 
   void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
                              const gbm::GBTreeModel& model, size_t tree_begin,
-                             size_t tree_end) {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
+                             size_t tree_end) const {
     if (tree_end - tree_begin == 0) {
       return;
     }
-    model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
     out_preds->SetDevice(generic_param_->gpu_id);
     auto const& info = dmat->Info();
+    DeviceModel d_model;
+    d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
 
     if (dmat->PageExists<SparsePage>()) {
       size_t batch_offset = 0;
       for (auto &batch : dmat->GetBatches<SparsePage>()) {
-        this->PredictInternal(batch, model.learner_model_param->num_feature,
+        this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
                               out_preds, batch_offset);
         batch_offset += batch.Size() * model.learner_model_param->num_output_group;
       }
@@ -572,6 +574,7 @@ class GPUPredictor : public xgboost::Predictor {
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
         this->PredictInternal(
             page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
+            d_model,
             out_preds,
             batch_offset);
         batch_offset += page.Impl()->n_rows;
@@ -591,10 +594,9 @@ class GPUPredictor : public xgboost::Predictor {
 
   void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                     const gbm::GBTreeModel& model, int tree_begin,
-                    unsigned ntree_limit = 0) override {
+                    unsigned ntree_limit = 0) const override {
     // This function is duplicated with CPU predictor PredictBatch, see comments in there.
     // FIXME(trivialfis): Remove the duplication.
-    std::lock_guard<std::mutex> const guard(lock_);
     int device = generic_param_->gpu_id;
     CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
     ConfigureDevice(device);
@@ -702,7 +704,7 @@ class GPUPredictor : public xgboost::Predictor {
                            const gbm::GBTreeModel& model, unsigned ntree_limit,
                            std::vector<bst_float>*,
                            bool approximate, int,
-                           unsigned) override {
+                           unsigned) const override {
     if (approximate) {
       LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
     }
@@ -755,7 +757,7 @@ class GPUPredictor : public xgboost::Predictor {
                                        const gbm::GBTreeModel& model,
                                        unsigned ntree_limit,
                                        std::vector<bst_float>*,
-                                       bool approximate) override {
+                                       bool approximate) const override {
     if (approximate) {
       LOG(FATAL) << "[Internal error]: " << __func__
                  << " approximate is not implemented in GPU Predictor.";
@@ -828,21 +830,21 @@ class GPUPredictor : public xgboost::Predictor {
 
   void PredictInstance(const SparsePage::Inst&,
                        std::vector<bst_float>*,
-                       const gbm::GBTreeModel&, unsigned) override {
+                       const gbm::GBTreeModel&, unsigned) const override {
     LOG(FATAL) << "[Internal error]: " << __func__
                << " is not implemented in GPU Predictor.";
   }
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
                    const gbm::GBTreeModel& model,
-                   unsigned ntree_limit) override {
+                   unsigned ntree_limit) const override {
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    ConfigureDevice(generic_param_->gpu_id);
+    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
 
     const MetaInfo& info = p_fmat->Info();
     constexpr uint32_t kBlockThreads = 128;
     size_t shared_memory_bytes =
-        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
+        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
     bst_feature_t num_features = info.num_col_;
     bst_row_t num_rows = info.num_row_;
@@ -854,7 +856,8 @@ class GPUPredictor : public xgboost::Predictor {
     }
     predictions->SetDevice(generic_param_->gpu_id);
     predictions->Resize(num_rows * real_ntree_limit);
-    model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);
+    DeviceModel d_model;
+    d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id);
 
     if (p_fmat->PageExists<SparsePage>()) {
       for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
@@ -868,10 +871,10 @@ class GPUPredictor : public xgboost::Predictor {
             static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
         dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
             PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
-            model_.nodes.ConstDeviceSpan(),
+            d_model.nodes.ConstDeviceSpan(),
             predictions->DeviceSpan().subspan(batch_offset),
-            model_.tree_segments.ConstDeviceSpan(),
-            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            d_model.tree_segments.ConstDeviceSpan(),
+            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
             entry_start, use_shared);
         batch_offset += batch.Size();
       }
@@ -884,10 +887,10 @@ class GPUPredictor : public xgboost::Predictor {
             static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
         dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
             PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
-            model_.nodes.ConstDeviceSpan(),
+            d_model.nodes.ConstDeviceSpan(),
             predictions->DeviceSpan().subspan(batch_offset),
-            model_.tree_segments.ConstDeviceSpan(),
-            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            d_model.tree_segments.ConstDeviceSpan(),
+            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
             entry_start, use_shared);
         batch_offset += batch.Size();
       }
@@ -900,15 +903,12 @@ class GPUPredictor : public xgboost::Predictor {
 
  private:
   /*! \brief Reconfigure the device when GPU is changed. */
-  void ConfigureDevice(int device) {
+  static size_t ConfigureDevice(int device) {
     if (device >= 0) {
-      max_shared_memory_bytes_ = dh::MaxSharedMemory(device);
+      return dh::MaxSharedMemory(device);
     }
+    return 0;
   }
 
-  std::mutex lock_;
-  DeviceModel model_;
-  size_t max_shared_memory_bytes_ { 0 };
 };
 
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
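Test: `Learner.MultiThreadedPredict` shrinks the feature count and now also exercises leaf and contribution predictions from every thread. A hedged sketch of the usage pattern the commit is meant to guarantee, with a hypothetical `Model` type standing in for the learner:

#include <thread>
#include <vector>

struct Model {
  // Read-only prediction; safe to call from many threads at once.
  std::vector<float> Predict(int rows) const { return std::vector<float>(rows, 0.0f); }
};

int main() {
  Model model;  // shared model, never mutated during prediction
  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    threads.emplace_back([&model] {
      for (int iter = 0; iter < 10; ++iter) {
        auto out = model.Predict(1000);  // each thread keeps its own output buffer
        (void)out;
      }
    });
  }
  for (auto& th : threads) { th.join(); }
  return 0;
}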
@@ -199,7 +199,7 @@ TEST(Learner, JsonModelIO) {
 // ```
 TEST(Learner, MultiThreadedPredict) {
   size_t constexpr kRows = 1000;
-  size_t constexpr kCols = 1000;
+  size_t constexpr kCols = 100;
 
   std::shared_ptr<DMatrix> p_dmat{
       RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
@@ -219,8 +219,11 @@ TEST(Learner, MultiThreadedPredict) {
     threads.emplace_back([learner, p_data] {
       size_t constexpr kIters = 10;
       auto &entry = learner->GetThreadLocal().prediction_entry;
+      HostDeviceVector<float> predictions;
       for (size_t iter = 0; iter < kIters; ++iter) {
         learner->Predict(p_data, false, &entry.predictions);
+        learner->Predict(p_data, false, &predictions, 0, true);         // leaf
+        learner->Predict(p_data, false, &predictions, 0, false, true);  // contribs
       }
     });
   }