Avoid thread block with sparse data. (#7255)
This commit is contained in:
@@ -253,17 +253,32 @@ class CPUPredictor : public Predictor {
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) const {
|
||||
const int threads = omp_get_max_threads();
|
||||
constexpr double kDensityThresh = .5;
|
||||
size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
|
||||
static_cast<uint64_t>(1));
|
||||
double density = static_cast<double>(p_fmat->Info().num_nonzero_) /
|
||||
static_cast<double>(total);
|
||||
bool blocked = density > kDensityThresh;
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(threads * kBlockOfRowsSize,
|
||||
InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1),
|
||||
model.learner_model_param->num_feature, &feat_vecs);
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
|
||||
p_fmat->Info().num_row_ *
|
||||
model.learner_model_param->num_output_group);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
|
||||
kBlockOfRowsSize>(SparsePageView<kUnroll>{&batch},
|
||||
out_preds, model, tree_begin,
|
||||
tree_end, &feat_vecs);
|
||||
if (blocked) {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
|
||||
kBlockOfRowsSize>(
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &feat_vecs);
|
||||
|
||||
} else {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
|
||||
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &feat_vecs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,7 +331,7 @@ class CPUPredictor : public Predictor {
|
||||
tree_end);
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
template <typename Adapter, size_t kBlockSize>
|
||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
@@ -336,9 +351,9 @@ class CPUPredictor : public Predictor {
|
||||
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads * kBlockOfRowsSize,
|
||||
model.learner_model_param->num_feature, &thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
|
||||
InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature,
|
||||
&thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
@@ -348,16 +363,16 @@ class CPUPredictor : public Predictor {
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
this->DispatchedInplacePredict<data::CSRAdapter, 1>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter> (
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter> (
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user