Avoid thread block with sparse data. (#7255)
parent ca17f8a5fc
commit d8a549e6ac
@@ -253,17 +253,32 @@ class CPUPredictor : public Predictor {
                         gbm::GBTreeModel const &model, int32_t tree_begin,
                         int32_t tree_end) const {
     const int threads = omp_get_max_threads();
+    constexpr double kDensityThresh = .5;
+    size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
+                            static_cast<uint64_t>(1));
+    double density = static_cast<double>(p_fmat->Info().num_nonzero_) /
+                     static_cast<double>(total);
+    bool blocked = density > kDensityThresh;
+
     std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(threads * kBlockOfRowsSize,
+    InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1),
                    model.learner_model_param->num_feature, &feat_vecs);
     for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
-               p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
+               p_fmat->Info().num_row_ *
+                   model.learner_model_param->num_output_group);
       size_t constexpr kUnroll = 8;
-      PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
-                                      kBlockOfRowsSize>(SparsePageView<kUnroll>{&batch},
-                                                        out_preds, model, tree_begin,
-                                                        tree_end, &feat_vecs);
+      if (blocked) {
+        PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
+                                        kBlockOfRowsSize>(
+            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
+            tree_end, &feat_vecs);
+      } else {
+        PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
+            SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
+            tree_end, &feat_vecs);
+      }
     }
   }

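The heuristic added above only keeps the blocked kernel when the data is mostly dense; with sparse rows, a full block of thread-local feature vectors is initialized for very little work, and threads end up waiting on each other. Below is a minimal standalone sketch of that decision, with a hypothetical ChooseBlockSize helper and an assumed kBlockOfRowsSize of 64 (the real constant and MetaInfo fields live in cpu_predictor.cc and are not reproduced here):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical standalone version of the density heuristic sketched above.
// kBlockOfRowsSize = 64 is an assumption for illustration only.
constexpr std::size_t kBlockOfRowsSize = 64;
constexpr double kDensityThresh = 0.5;

std::size_t ChooseBlockSize(std::uint64_t n_rows, std::uint64_t n_cols,
                            std::uint64_t n_nonzero) {
  // Guard against an empty matrix so the division below is well defined.
  std::uint64_t total = std::max<std::uint64_t>(n_rows * n_cols, 1);
  double density = static_cast<double>(n_nonzero) / static_cast<double>(total);
  // Dense data: process rows in blocks so thread-local buffers are reused.
  // Sparse data: one row per block to avoid threads idling on uneven rows.
  return density > kDensityThresh ? kBlockOfRowsSize : 1;
}

int main() {
  std::cout << ChooseBlockSize(512, 128, 512 * 128) << "\n";  // dense  -> 64
  std::cout << ChooseBlockSize(512, 128, 512 * 16) << "\n";   // sparse -> 1
}
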
@@ -316,7 +331,7 @@ class CPUPredictor : public Predictor {
                         tree_end);
   }

-  template <typename Adapter>
+  template <typename Adapter, size_t kBlockSize>
   void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                                 const gbm::GBTreeModel &model, float missing,
                                 PredictionCacheEntry *out_preds,
@@ -336,9 +351,9 @@ class CPUPredictor : public Predictor {
     std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
     auto &predictions = out_preds->predictions.HostVector();
     std::vector<RegTree::FVec> thread_temp;
-    InitThreadTemp(threads * kBlockOfRowsSize,
-                   model.learner_model_param->num_feature, &thread_temp);
-    PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
+    InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature,
+                   &thread_temp);
+    PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
         AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
         &predictions, model, tree_begin, tree_end, &thread_temp);
   }
@@ -348,16 +363,16 @@ class CPUPredictor : public Predictor {
                       PredictionCacheEntry *out_preds, uint32_t tree_begin,
                       unsigned tree_end) const override {
     if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
-      this->DispatchedInplacePredict<data::DenseAdapter>(
+      this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
           x, p_m, model, missing, out_preds, tree_begin, tree_end);
     } else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
-      this->DispatchedInplacePredict<data::CSRAdapter>(
+      this->DispatchedInplacePredict<data::CSRAdapter, 1>(
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
     } else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
-      this->DispatchedInplacePredict<data::ArrayAdapter> (
+      this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize> (
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
     } else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
-      this->DispatchedInplacePredict<data::CSRArrayAdapter> (
+      this->DispatchedInplacePredict<data::CSRArrayAdapter, 1> (
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
     } else {
       return false;

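With the extra template parameter, each adapter type picks its block size at compile time: dense, row-major adapters keep kBlockOfRowsSize, while CSR-style adapters drop to a block size of 1. A hedged sketch of the same compile-time selection pattern, using simplified stand-in types and a hypothetical PredictWith function rather than the real xgboost adapters:

#include <cstddef>
#include <iostream>

// Simplified stand-ins for the adapter types; not the real xgboost classes.
struct DenseLikeAdapter {};
struct CsrLikeAdapter {};

constexpr std::size_t kBlockOfRowsSize = 64;  // assumed value, for illustration

// The block size is a compile-time parameter, as in DispatchedInplacePredict.
template <typename Adapter, std::size_t kBlockSize>
void PredictWith(const Adapter &) {
  std::cout << "predicting with block size " << kBlockSize << "\n";
}

// Call sites choose the block size per adapter layout: blocked for dense
// input, one row at a time for sparse (CSR-like) input.
void Dispatch(const DenseLikeAdapter &a) {
  PredictWith<DenseLikeAdapter, kBlockOfRowsSize>(a);
}
void Dispatch(const CsrLikeAdapter &a) { PredictWith<CsrLikeAdapter, 1>(a); }

int main() {
  Dispatch(DenseLikeAdapter{});  // block size 64
  Dispatch(CsrLikeAdapter{});    // block size 1
}
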
@@ -247,7 +247,7 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins,
   ASSERT_TRUE(is_unique);

   x.resize(n_uniques);
-  for (size_t i = 0; i < n_uniques; ++i) {
+  for (decltype(n_uniques) i = 0; i < n_uniques; ++i) {
     ASSERT_EQ(x[i], values[i]);
   }
 }

@@ -247,4 +247,9 @@ TEST(CpuPredictor, UpdatePredictionCache) {
 TEST(CpuPredictor, LesserFeatures) {
   TestPredictionWithLesserFeatures("cpu_predictor");
 }
+
+TEST(CpuPredictor, Sparse) {
+  TestSparsePrediction(0.2, "cpu_predictor");
+  TestSparsePrediction(0.8, "cpu_predictor");
+}
 }  // namespace xgboost

@@ -256,5 +256,10 @@ TEST(GPUPredictor, PredictLeafBasic) {
     ASSERT_EQ(v, 0);
   }
 }
+
+TEST(GPUPredictor, Sparse) {
+  TestSparsePrediction(0.2, "gpu_predictor");
+  TestSparsePrediction(0.8, "gpu_predictor");
+}
 }  // namespace predictor
 }  // namespace xgboost

@@ -11,6 +11,7 @@
 #include "test_predictor.h"

 #include "../helpers.h"
+#include "../../../src/data/adapter.h"
 #include "../../../src/common/io.h"
 #include "../../../src/common/categorical.h"
 #include "../../../src/common/bitfield.h"

@@ -355,4 +356,57 @@ void TestIterationRange(std::string name) {
     ASSERT_EQ(h_sliced, h_range);
   }
 }
+
+void TestSparsePrediction(float sparsity, std::string predictor) {
+  size_t constexpr kRows = 512, kCols = 128;
+  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
+  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+  learner->Configure();
+  for (size_t i = 0; i < 4; ++i) {
+    learner->UpdateOneIter(i, Xy);
+  }
+
+  HostDeviceVector<float> sparse_predt;
+
+  Json model{Object{}};
+  learner->SaveModel(&model);
+
+  learner.reset(Learner::Create({Xy}));
+  learner->LoadModel(model);
+
+  learner->SetParam("predictor", predictor);
+  learner->Predict(Xy, false, &sparse_predt, 0, 0);
+
+  std::vector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
+  for (auto const& page : Xy->GetBatches<SparsePage>()) {
+    auto batch = page.GetView();
+    for (size_t i = 0; i < batch.Size(); ++i) {
+      auto row = batch[i];
+      for (auto e : row) {
+        with_nan[i * kCols + e.index] = e.fvalue;
+      }
+    }
+  }
+
+  learner->SetParam("predictor", "cpu_predictor");
+  // Xcode_12.4 doesn't compile with `std::make_shared`.
+  auto dense = std::shared_ptr<data::DenseAdapter>(
+      new data::DenseAdapter(with_nan.data(), kRows, kCols));
+  HostDeviceVector<float> *p_dense_predt;
+  learner->InplacePredict(dmlc::any(dense), nullptr, PredictionType::kValue,
+                          std::numeric_limits<float>::quiet_NaN(), &p_dense_predt,
+                          0, 0);
+
+  auto const& dense_predt = *p_dense_predt;
+  if (predictor == "cpu_predictor") {
+    ASSERT_EQ(dense_predt.HostVector(), sparse_predt.HostVector());
+  } else {
+    auto const &h_dense = dense_predt.HostVector();
+    auto const &h_sparse = sparse_predt.HostVector();
+    ASSERT_EQ(h_dense.size(), h_sparse.size());
+    for (size_t i = 0; i < h_dense.size(); ++i) {
+      ASSERT_FLOAT_EQ(h_dense[i], h_sparse[i]);
+    }
+  }
+}
 }  // namespace xgboost

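The new test builds a dense copy of the sparse matrix by filling a row-major buffer with NaN and writing back only the stored entries, then checks that sparse and densified inputs produce the same predictions. A small self-contained sketch of that densification step, using a toy Entry type and a hypothetical Densify helper rather than xgboost's SparsePage view:

#include <cstddef>
#include <limits>
#include <vector>

// Toy stand-in for a sparse entry: column index plus feature value.
struct Entry {
  std::size_t index;
  float fvalue;
};

// Expand sparse rows into a rows x n_cols row-major buffer, NaN for missing,
// mirroring the fill loop in TestSparsePrediction above.
std::vector<float> Densify(const std::vector<std::vector<Entry>> &rows,
                           std::size_t n_cols) {
  std::vector<float> dense(rows.size() * n_cols,
                           std::numeric_limits<float>::quiet_NaN());
  for (std::size_t i = 0; i < rows.size(); ++i) {
    for (const auto &e : rows[i]) {
      dense[i * n_cols + e.index] = e.fvalue;
    }
  }
  return dense;
}

int main() {
  std::vector<std::vector<Entry>> rows{{{0, 1.0f}, {3, 2.5f}}, {{2, -1.0f}}};
  // row 0: [1.0, NaN, NaN, 2.5]; row 1: [NaN, NaN, -1.0, NaN]
  auto dense = Densify(rows, 4);
  return dense.size() == 8 ? 0 : 1;
}
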
@@ -70,6 +70,8 @@ void TestCategoricalPrediction(std::string name);
 void TestCategoricalPredictLeaf(StringView name);

 void TestIterationRange(std::string name);
+
+void TestSparsePrediction(float sparsity, std::string predictor);
 }  // namespace xgboost

 #endif  // XGBOOST_TEST_PREDICTOR_H_