Thread safe, inplace prediction. (#5389)
Normal prediction with DMatrix is now thread safe with locks. Added inplace prediction is lock free thread safe. When data is on device (cupy, cudf), the returned data is also on device. * Implementation for numpy, csr, cudf and cupy. * Implementation for dask. * Remove sync in simple dmatrix.
This commit is contained in:
@@ -2,13 +2,22 @@
|
||||
* Copyright by Contributors 2017-2020
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/any.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "../data/adapter.h"
|
||||
#include "../common/math.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -16,89 +25,156 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
bst_float PredValue(const SparsePage::Inst &inst,
|
||||
const std::vector<std::unique_ptr<RegTree>> &trees,
|
||||
const std::vector<int> &tree_info, int bst_group,
|
||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
template <size_t kUnrollLen = 8>
|
||||
struct SparsePageView {
|
||||
SparsePage const* page;
|
||||
bst_row_t base_rowid;
|
||||
static size_t constexpr kUnroll = kUnrollLen;
|
||||
|
||||
explicit SparsePageView(SparsePage const *p)
|
||||
: page{p}, base_rowid{page->base_rowid} {
|
||||
// Pull to host before entering omp block, as this is not thread safe.
|
||||
page->data.HostVector();
|
||||
page->offset.HostVector();
|
||||
}
|
||||
SparsePage::Inst operator[](size_t i) { return (*page)[i]; }
|
||||
size_t Size() const { return page->Size(); }
|
||||
};
|
||||
|
||||
template <typename Adapter, size_t kUnrollLen = 8>
|
||||
class AdapterView {
|
||||
Adapter* adapter_;
|
||||
float missing_;
|
||||
common::Span<Entry> workspace_;
|
||||
std::vector<size_t> current_unroll_;
|
||||
|
||||
public:
|
||||
static size_t constexpr kUnroll = kUnrollLen;
|
||||
|
||||
public:
|
||||
explicit AdapterView(Adapter *adapter, float missing,
|
||||
common::Span<Entry> workplace)
|
||||
: adapter_{adapter}, missing_{missing}, workspace_{workplace},
|
||||
current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
|
||||
SparsePage::Inst operator[](size_t i) {
|
||||
bst_feature_t columns = adapter_->NumColumns();
|
||||
auto const &batch = adapter_->Value();
|
||||
auto row = batch.GetLine(i);
|
||||
auto t = omp_get_thread_num();
|
||||
auto const beg = (columns * kUnroll * t) + (current_unroll_[t] * columns);
|
||||
size_t non_missing {beg};
|
||||
for (size_t c = 0; c < row.Size(); ++c) {
|
||||
auto e = row.GetElement(c);
|
||||
if (missing_ != e.value && !common::CheckNAN(e.value)) {
|
||||
workspace_[non_missing] =
|
||||
Entry{static_cast<bst_feature_t>(e.column_idx), e.value};
|
||||
++non_missing;
|
||||
}
|
||||
}
|
||||
auto ret = workspace_.subspan(beg, non_missing - beg);
|
||||
current_unroll_[t]++;
|
||||
if (current_unroll_[t] == kUnroll) {
|
||||
current_unroll_[t] = 0;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t Size() const { return adapter_->NumRows(); }
|
||||
|
||||
bst_row_t const static base_rowid = 0; // NOLINT
|
||||
};
|
||||
|
||||
template <typename DataView>
|
||||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
auto& thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param_->num_output_group;
|
||||
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
auto constexpr kUnroll = DataView::kUnroll;
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec &feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += PredValue(inst[k], model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||
RegTree::FVec &feats = thread_temp[0];
|
||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
auto inst = batch[i];
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] += PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CPUPredictor : public Predictor {
|
||||
protected:
|
||||
static bst_float PredValue(const SparsePage::Inst& inst,
|
||||
const std::vector<std::unique_ptr<RegTree>>& trees,
|
||||
const std::vector<int>& tree_info, int bst_group,
|
||||
RegTree::FVec* p_feats,
|
||||
unsigned tree_begin, unsigned tree_end) {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
// init thread buffers
|
||||
inline void InitThreadTemp(int nthread, int num_feature) {
|
||||
int prev_thread_temp_size = thread_temp.size();
|
||||
static void InitThreadTemp(int nthread, int num_feature, std::vector<RegTree::FVec>* out) {
|
||||
int prev_thread_temp_size = out->size();
|
||||
if (prev_thread_temp_size < nthread) {
|
||||
thread_temp.resize(nthread, RegTree::FVec());
|
||||
out->resize(nthread, RegTree::FVec());
|
||||
for (int i = prev_thread_temp_size; i < nthread; ++i) {
|
||||
thread_temp[i].Init(num_feature);
|
||||
(*out)[i].Init(num_feature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PredInternal(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) {
|
||||
int32_t const num_group = model.learner_model_param_->num_output_group;
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
std::vector<bst_float>& preds = *out_preds;
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
// Pull to host before entering omp block, as this is not thread safe.
|
||||
batch.data.HostVector();
|
||||
batch.offset.HostVector();
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += this->PredValue(
|
||||
inst[k], model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||
RegTree::FVec& feats = thread_temp[0];
|
||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
auto inst = batch[i];
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
this->PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
const int threads = omp_get_max_threads();
|
||||
InitThreadTemp(threads, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model.learner_model_param_->num_output_group);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchKernel(SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &thread_temp_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,9 +251,9 @@ class CPUPredictor : public Predictor {
|
||||
CHECK_LE(beg_version, end_version);
|
||||
|
||||
if (beg_version < end_version) {
|
||||
this->PredInternal(dmat, &out_preds->HostVector(), model,
|
||||
beg_version * output_groups,
|
||||
end_version * output_groups);
|
||||
this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
|
||||
beg_version * output_groups,
|
||||
end_version * output_groups);
|
||||
}
|
||||
|
||||
// delta means {size of forest} * {number of newly accumulated layers}
|
||||
@@ -189,12 +265,49 @@ class CPUPredictor : public Predictor {
|
||||
out_preds->Size() == dmat->Info().num_row_);
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
auto threads = omp_get_max_threads();
|
||||
auto m = dmlc::get<Adapter>(x);
|
||||
CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
MetaInfo info;
|
||||
info.num_col_ = m.NumColumns();
|
||||
info.num_row_ = m.NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads, model.learner_model_param_->num_feature, &thread_temp);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
|
||||
&m, missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
if (x.type() == typeid(data::DenseAdapter)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(data::CSRAdapter)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Data type is not supported by CPU Predictor.";
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
if (thread_temp.size() == 0) {
|
||||
thread_temp.resize(1, RegTree::FVec());
|
||||
thread_temp[0].Init(model.learner_model_param_->num_feature);
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
thread_temp_[0].Init(model.learner_model_param_->num_feature);
|
||||
}
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
@@ -204,16 +317,16 @@ class CPUPredictor : public Predictor {
|
||||
(model.param.size_leaf_vector + 1));
|
||||
// loop over output groups
|
||||
for (uint32_t gid = 0; gid < model.learner_model_param_->num_output_group; ++gid) {
|
||||
(*out_preds)[gid] =
|
||||
PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&thread_temp[0], 0, ntree_limit) +
|
||||
model.learner_model_param_->base_score;
|
||||
(*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&thread_temp_[0], 0, ntree_limit) +
|
||||
model.learner_model_param_->base_score;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
@@ -230,7 +343,7 @@ class CPUPredictor : public Predictor {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
auto ridx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
RegTree::FVec &feats = thread_temp_[tid];
|
||||
feats.Fill(batch[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = model.trees[j]->GetLeafIndex(feats);
|
||||
@@ -247,7 +360,7 @@ class CPUPredictor : public Predictor {
|
||||
bool approximate, int condition,
|
||||
unsigned condition_feature) override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
@@ -277,7 +390,7 @@ class CPUPredictor : public Predictor {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec& feats = thread_temp[omp_get_thread_num()];
|
||||
RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
|
||||
std::vector<bst_float> this_tree_contribs(ncolumns);
|
||||
// loop over all classes
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
@@ -359,7 +472,10 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
std::vector<RegTree::FVec> thread_temp_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
|
||||
|
||||
Reference in New Issue
Block a user