Prediction by indices (subsample < 1) (#6683)

* Another implementation of predicting by indices

* Fixed omp parallel_for variable type

* Removed SparsePageView from Updater
This commit is contained in:
Igor Rukhovich
2021-03-16 05:08:20 +03:00
committed by GitHub
parent 366f3cb9d8
commit 19a2c54265
4 changed files with 139 additions and 59 deletions

View File

@@ -183,7 +183,7 @@ TEST(CpuPredictor, InplacePredict) {
}
}
TEST(CpuPredictor, UpdatePredictionCache) {
void TestUpdatePredictionCache(bool use_subsampling) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam;
mparam.num_feature = kCols;
@@ -198,6 +198,9 @@ TEST(CpuPredictor, UpdatePredictionCache) {
std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor";
if (use_subsampling) {
cfg["subsample"] = "0.5";
}
Args args = {cfg.cbegin(), cfg.cend()};
gbm->Configure(args);
@@ -226,6 +229,11 @@ TEST(CpuPredictor, UpdatePredictionCache) {
}
}
// Exercises the prediction-cache update check twice: first with subsampling
// disabled, then with it enabled (same order as the two explicit calls).
TEST(CpuPredictor, UpdatePredictionCache) {
  const bool subsampling_modes[] = {false, true};
  for (const bool use_subsampling : subsampling_modes) {
    TestUpdatePredictionCache(use_subsampling);
  }
}
// Invokes the TestPredictionWithLesserFeatures helper (defined outside this
// view) with the predictor name "cpu_predictor".
TEST(CpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("cpu_predictor");
}

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2018-2019 by Contributors
* Copyright 2018-2021 by Contributors
*/
#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>
@@ -107,19 +107,45 @@ class QuantileHistMock : public QuantileHistMaker {
const size_t nthreads = omp_get_num_threads();
// save state of global rng engine
auto initial_rnd = common::GlobalRandom();
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
// Test helper: checks that `first` and `second` together reference every row
// index in [0, nrows) exactly once, i.e. the two lists are disjoint and
// jointly cover all rows.
//   first, second - row-index lists to validate
//   nrows         - number of rows every index must fall below
// A failed ASSERT_EQ returns from this lambda only, because gtest fatal
// assertions abort just the function they appear in.
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
                                                 const std::vector<size_t>& second,
                                                 size_t nrows) {
  // One zero-initialised occurrence counter per row.
  std::vector<size_t> arr_union(nrows);
  // at() rather than operator[]: an index >= nrows is itself an error in the
  // data being validated and must throw std::out_of_range instead of
  // silently writing past the end of the counter array.
  for (const size_t row_indice : first) {
    ++arr_union.at(row_indice);
  }
  for (const size_t row_indice : second) {
    ++arr_union.at(row_indice);
  }
  // Every row must have been seen exactly once across both lists.
  for (const size_t row_cnt : arr_union) {
    ASSERT_EQ(row_cnt, 1ul);
  }
};
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
p_fmat->Info().num_row_);
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
omp_set_num_threads(i_nthreads);
// return initial state of global rng engine
common::GlobalRandom() = initial_rnd;
this->unused_rows_ = unused_rows_cpy;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
}
std::vector<size_t>& unused_row_indices = this->unused_rows_;
ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
}
check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
p_fmat->Info().num_row_);
}
omp_set_num_threads(nthreads);
}