Prediction by indices (subsample < 1) (#6683)

* Another implementation of predicting by indices

* Fixed omp parallel_for variable type

* Removed SparsePageView from Updater
This commit is contained in:
Igor Rukhovich 2021-03-16 05:08:20 +03:00 committed by GitHub
parent 366f3cb9d8
commit 19a2c54265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 139 additions and 59 deletions

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2020 by Contributors * Copyright 2017-2021 by Contributors
* \file updater_quantile_hist.cc * \file updater_quantile_hist.cc
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn, Egor Smirnov * \author Philip Cho, Tianqi Checn, Egor Smirnov
@ -7,15 +7,15 @@
#include <dmlc/timer.h> #include <dmlc/timer.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <cmath>
#include <memory>
#include <vector>
#include <algorithm> #include <algorithm>
#include <queue> #include <cmath>
#include <iomanip> #include <iomanip>
#include <memory>
#include <numeric> #include <numeric>
#include <queue>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
@ -111,32 +111,24 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
bool QuantileHistMaker::UpdatePredictionCache( bool QuantileHistMaker::UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) { const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
if (param_.subsample < 1.0f) { if (hist_maker_param_.single_precision_histogram && float_builder_) {
return false; return float_builder_->UpdatePredictionCache(data, out_preds);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds);
} else { } else {
if (hist_maker_param_.single_precision_histogram && float_builder_) { return false;
return float_builder_->UpdatePredictionCache(data, out_preds);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds);
} else {
return false;
}
} }
} }
bool QuantileHistMaker::UpdatePredictionCacheMulticlass( bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
const DMatrix* data, const DMatrix* data,
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) { HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
if (param_.subsample < 1.0f) { if (hist_maker_param_.single_precision_histogram && float_builder_) {
return false; return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else { } else {
if (hist_maker_param_.single_precision_histogram && float_builder_) { return false;
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else {
return false;
}
} }
} }
@ -615,6 +607,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
tree_evaluator_ = tree_evaluator_ =
TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId); TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
interaction_constraints_.Reset(); interaction_constraints_.Reset();
p_last_fmat_mutable_ = p_fmat;
this->InitData(gmat, gpair_h, *p_fmat, *p_tree); this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
if (param_.grow_policy == TrainParam::kLossGuide) { if (param_.grow_policy == TrainParam::kLossGuide) {
@ -639,18 +632,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) { HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update(). // conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
p_last_fmat_ != p_last_fmat_mutable_) {
return false; return false;
} }
builder_monitor_.Start("UpdatePredictionCache"); builder_monitor_.Start("UpdatePredictionCache");
std::vector<bst_float>& out_preds = p_out_preds->HostVector(); std::vector<bst_float>& out_preds = p_out_preds->HostVector();
if (leaf_value_cache_.empty()) {
leaf_value_cache_.resize(p_last_tree_->param.num_nodes,
std::numeric_limits<float>::infinity());
}
CHECK_GT(out_preds.size(), 0U); CHECK_GT(out_preds.size(), 0U);
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
@ -680,30 +669,60 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
} }
}); });
if (param_.subsample < 1.0f) {
// Making a real prediction for the remaining rows
size_t fvecs_size = feat_vecs_.size();
feat_vecs_.resize(omp_get_max_threads(), RegTree::FVec());
while (fvecs_size < feat_vecs_.size()) {
feat_vecs_[fvecs_size++].Init(data->Info().num_col_);
}
for (auto&& batch : p_last_fmat_mutable_->GetBatches<SparsePage>()) {
HostSparsePageView page_view = batch.GetView();
const auto num_parallel_ops = static_cast<bst_omp_uint>(unused_rows_.size());
common::ParallelFor(num_parallel_ops, [&](bst_omp_uint block_id) {
RegTree::FVec &feats = feat_vecs_[omp_get_thread_num()];
const SparsePage::Inst inst = page_view[unused_rows_[block_id]];
feats.Fill(inst);
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
p_last_tree_->GetLeafIndex<false>(feats);
out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue();
feats.Drop(inst);
});
}
}
builder_monitor_.Stop("UpdatePredictionCache"); builder_monitor_.Stop("UpdatePredictionCache");
return true; return true;
} }
template<typename GradientSumT> template<typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair, void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair,
const DMatrix& fmat, const DMatrix& fmat,
std::vector<size_t>* row_indices) { std::vector<size_t>* row_indices) {
const auto& info = fmat.Info(); const auto& info = fmat.Info();
auto& rnd = common::GlobalRandom(); auto& rnd = common::GlobalRandom();
std::vector<size_t>& row_indices_local = *row_indices; unused_rows_.resize(info.num_row_);
size_t* p_row_indices = row_indices_local.data(); size_t* p_row_indices_used = row_indices->data();
size_t* p_row_indices_unused = unused_rows_.data();
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG #if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
std::bernoulli_distribution coin_flip(param_.subsample); std::bernoulli_distribution coin_flip(param_.subsample);
size_t j = 0; size_t used = 0, unused = 0;
for (size_t i = 0; i < info.num_row_; ++i) { for (size_t i = 0; i < info.num_row_; ++i) {
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) { if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
p_row_indices[j++] = i; p_row_indices_used[used++] = i;
} else {
p_row_indices_unused[unused++] = i;
} }
} }
/* resize row_indices to reduce memory */ /* resize row_indices to reduce memory */
row_indices_local.resize(j); row_indices->resize(used);
unused_rows_.resize(unused);
#else #else
const size_t nthread = this->nthread_; const size_t nthread = this->nthread_;
std::vector<size_t> row_offsets(nthread, 0); std::vector<size_t> row_offsets_used(nthread, 0);
std::vector<size_t> row_offsets_unused(nthread, 0);
/* usage of mt19937_64 give 2x speed up for subsampling */ /* usage of mt19937_64 give 2x speed up for subsampling */
std::vector<std::mt19937> rnds(nthread); std::vector<std::mt19937> rnds(nthread);
/* create engine for each thread */ /* create engine for each thread */
@ -725,27 +744,51 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<Gr
rnds[tid].discard(discard_size * tid); rnds[tid].discard(discard_size * tid);
for (size_t i = ibegin; i < iend; ++i) { for (size_t i = ibegin; i < iend; ++i) {
if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) { if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) {
p_row_indices[ibegin + row_offsets[tid]++] = i; p_row_indices_used[ibegin + row_offsets_used[tid]++] = i;
} else {
p_row_indices_unused[ibegin + row_offsets_unused[tid]++] = i;
} }
} }
#pragma omp barrier
if (tid == 0ul) {
size_t prefix_sum_used = row_offsets_used[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;
for (size_t k = 0; k < row_offsets_used[i]; ++k) {
p_row_indices_used[prefix_sum_used + k] = p_row_indices_used[ibegin + k];
}
prefix_sum_used += row_offsets_used[i];
}
/* resize row_indices to reduce memory */
row_indices->resize(prefix_sum_used);
}
if (nthread == 1ul || tid == 1ul) {
size_t prefix_sum_unused = row_offsets_unused[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;
for (size_t k = 0; k < row_offsets_unused[i]; ++k) {
p_row_indices_unused[prefix_sum_unused + k] = p_row_indices_unused[ibegin + k];
}
prefix_sum_unused += row_offsets_unused[i];
}
/* resize row_indices to reduce memory */
unused_rows_.resize(prefix_sum_unused);
}
}); });
} }
exc.Rethrow(); exc.Rethrow();
/* discard global engine */ /* discard global engine */
rnd = rnds[nthread - 1]; rnd = rnds[nthread - 1];
size_t prefix_sum = row_offsets[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;
for (size_t k = 0; k < row_offsets[i]; ++k) {
row_indices_local[prefix_sum + k] = row_indices_local[ibegin + k];
}
prefix_sum += row_offsets[i];
}
/* resize row_indices to reduce memory */
row_indices_local.resize(prefix_sum);
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
} }
template<typename GradientSumT> template<typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat, void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair, const std::vector<GradientPair>& gpair,
@ -764,8 +807,6 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
{ {
// initialize the row set // initialize the row set
row_set_collection_.Clear(); row_set_collection_.Clear();
// clear local prediction cache
leaf_value_cache_.clear();
// initialize histogram collection // initialize histogram collection
uint32_t nbins = gmat.cut.Ptrs().back(); uint32_t nbins = gmat.cut.Ptrs().back();
hist_.Init(nbins); hist_.Init(nbins);
@ -793,6 +834,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
<< "Only uniform sampling is supported, " << "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist."; << "gradient-based sampling is only support by GPU Hist.";
InitSampling(gpair, fmat, &row_indices); InitSampling(gpair, fmat, &row_indices);
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
CHECK_EQ(row_indices.size() + unused_rows_.size(), info.num_row_);
} else { } else {
MemStackAllocator<bool, 128> buff(this->nthread_); MemStackAllocator<bool, 128> buff(this->nthread_);
bool* p_buff = buff.Get(); bool* p_buff = buff.Get();

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2018 by Contributors * Copyright 2017-2021 by Contributors
* \file updater_quantile_hist.h * \file updater_quantile_hist.h
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Chen, Egor Smirnov * \author Philip Cho, Tianqi Chen, Egor Smirnov
@ -11,13 +11,12 @@
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <memory>
#include <vector>
#include <string>
#include <queue>
#include <iomanip> #include <iomanip>
#include <unordered_map> #include <memory>
#include <queue>
#include <string>
#include <utility> #include <utility>
#include <vector>
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/json.h" #include "xgboost/json.h"
@ -409,6 +408,10 @@ class QuantileHistMaker: public TreeUpdater {
common::ColumnSampler column_sampler_; common::ColumnSampler column_sampler_;
// the internal row sets // the internal row sets
RowSetCollection row_set_collection_; RowSetCollection row_set_collection_;
// tree rows that were not used for current training
std::vector<size_t> unused_rows_;
// feature vectors for subsampled prediction
std::vector<RegTree::FVec> feat_vecs_;
// the temp space for split // the temp space for split
std::vector<RowSetCollection::Split> row_split_tloc_; std::vector<RowSetCollection::Split> row_split_tloc_;
std::vector<SplitEntry> best_split_tloc_; std::vector<SplitEntry> best_split_tloc_;
@ -422,8 +425,6 @@ class QuantileHistMaker: public TreeUpdater {
/*! \brief feature with least # of bins. to be used for dense specialization /*! \brief feature with least # of bins. to be used for dense specialization
of InitNewNode() */ of InitNewNode() */
uint32_t fid_least_bins_; uint32_t fid_least_bins_;
/*! \brief local prediction cache; maps node id to leaf value */
std::vector<float> leaf_value_cache_;
GHistBuilder<GradientSumT> hist_builder_; GHistBuilder<GradientSumT> hist_builder_;
std::unique_ptr<TreeUpdater> pruner_; std::unique_ptr<TreeUpdater> pruner_;
@ -435,6 +436,7 @@ class QuantileHistMaker: public TreeUpdater {
// back pointers to tree and data matrix // back pointers to tree and data matrix
const RegTree* p_last_tree_; const RegTree* p_last_tree_;
DMatrix const* const p_last_fmat_; DMatrix const* const p_last_fmat_;
DMatrix* p_last_fmat_mutable_;
using ExpandQueue = using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>, std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,

View File

@ -183,7 +183,7 @@ TEST(CpuPredictor, InplacePredict) {
} }
} }
TEST(CpuPredictor, UpdatePredictionCache) { void TestUpdatePredictionCache(bool use_subsampling) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4; size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam; LearnerModelParam mparam;
mparam.num_feature = kCols; mparam.num_feature = kCols;
@ -198,6 +198,9 @@ TEST(CpuPredictor, UpdatePredictionCache) {
std::map<std::string, std::string> cfg; std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist"; cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor"; cfg["predictor"] = "cpu_predictor";
if (use_subsampling) {
cfg["subsample"] = "0.5";
}
Args args = {cfg.cbegin(), cfg.cend()}; Args args = {cfg.cbegin(), cfg.cend()};
gbm->Configure(args); gbm->Configure(args);
@ -226,6 +229,11 @@ TEST(CpuPredictor, UpdatePredictionCache) {
} }
} }
TEST(CpuPredictor, UpdatePredictionCache) {
TestUpdatePredictionCache(false);
TestUpdatePredictionCache(true);
}
TEST(CpuPredictor, LesserFeatures) { TEST(CpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("cpu_predictor"); TestPredictionWithLesserFeatures("cpu_predictor");
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2018-2019 by Contributors * Copyright 2018-2021 by Contributors
*/ */
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
@ -107,19 +107,45 @@ class QuantileHistMock : public QuantileHistMaker {
const size_t nthreads = omp_get_num_threads(); const size_t nthreads = omp_get_num_threads();
// save state of global rng engine // save state of global rng engine
auto initial_rnd = common::GlobalRandom(); auto initial_rnd = common::GlobalRandom();
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
RealImpl::InitData(gmat, gpair, *p_fmat, tree); RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data()); std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
const std::vector<size_t>& second,
size_t nrows) {
std::vector<size_t> arr_union(nrows);
for (auto&& row_indice : first) {
++arr_union[row_indice];
}
for (auto&& row_indice : second) {
++arr_union[row_indice];
}
for (auto&& row_cnt : arr_union) {
ASSERT_EQ(row_cnt, 1ul);
}
};
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
p_fmat->Info().num_row_);
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) { for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
omp_set_num_threads(i_nthreads); omp_set_num_threads(i_nthreads);
// return initial state of global rng engine // return initial state of global rng engine
common::GlobalRandom() = initial_rnd; common::GlobalRandom() = initial_rnd;
this->unused_rows_ = unused_rows_cpy;
RealImpl::InitData(gmat, gpair, *p_fmat, tree); RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data()); std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
ASSERT_EQ(row_indices_initial.size(), row_indices.size()); ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) { for (size_t i = 0; i < row_indices_initial.size(); ++i) {
ASSERT_EQ(row_indices_initial[i], row_indices[i]); ASSERT_EQ(row_indices_initial[i], row_indices[i]);
} }
std::vector<size_t>& unused_row_indices = this->unused_rows_;
ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
}
check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
p_fmat->Info().num_row_);
} }
omp_set_num_threads(nthreads); omp_set_num_threads(nthreads);
} }