Prediction by indices (subsample < 1) (#6683)
* Another implementation of predicting by indices * Fixed omp parallel_for variable type * Removed SparsePageView from Updater
This commit is contained in:
parent
366f3cb9d8
commit
19a2c54265
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 by Contributors
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \file updater_quantile_hist.cc
|
||||
* \brief use quantized feature values to construct a tree
|
||||
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
||||
@ -7,15 +7,15 @@
|
||||
#include <dmlc/timer.h>
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
@ -111,32 +111,24 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
|
||||
if (param_.subsample < 1.0f) {
|
||||
return false;
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
|
||||
if (param_.subsample < 1.0f) {
|
||||
return false;
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -615,6 +607,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
tree_evaluator_ =
|
||||
TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
|
||||
interaction_constraints_.Reset();
|
||||
p_last_fmat_mutable_ = p_fmat;
|
||||
|
||||
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
||||
if (param_.grow_policy == TrainParam::kLossGuide) {
|
||||
@ -639,18 +632,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
|
||||
p_last_fmat_ != p_last_fmat_mutable_) {
|
||||
return false;
|
||||
}
|
||||
builder_monitor_.Start("UpdatePredictionCache");
|
||||
|
||||
std::vector<bst_float>& out_preds = p_out_preds->HostVector();
|
||||
|
||||
if (leaf_value_cache_.empty()) {
|
||||
leaf_value_cache_.resize(p_last_tree_->param.num_nodes,
|
||||
std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
CHECK_GT(out_preds.size(), 0U);
|
||||
|
||||
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
|
||||
@ -680,30 +669,60 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
}
|
||||
});
|
||||
|
||||
if (param_.subsample < 1.0f) {
|
||||
// Making a real prediction for the remaining rows
|
||||
size_t fvecs_size = feat_vecs_.size();
|
||||
feat_vecs_.resize(omp_get_max_threads(), RegTree::FVec());
|
||||
while (fvecs_size < feat_vecs_.size()) {
|
||||
feat_vecs_[fvecs_size++].Init(data->Info().num_col_);
|
||||
}
|
||||
for (auto&& batch : p_last_fmat_mutable_->GetBatches<SparsePage>()) {
|
||||
HostSparsePageView page_view = batch.GetView();
|
||||
const auto num_parallel_ops = static_cast<bst_omp_uint>(unused_rows_.size());
|
||||
common::ParallelFor(num_parallel_ops, [&](bst_omp_uint block_id) {
|
||||
RegTree::FVec &feats = feat_vecs_[omp_get_thread_num()];
|
||||
const SparsePage::Inst inst = page_view[unused_rows_[block_id]];
|
||||
feats.Fill(inst);
|
||||
|
||||
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
|
||||
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
|
||||
p_last_tree_->GetLeafIndex<false>(feats);
|
||||
out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue();
|
||||
|
||||
feats.Drop(inst);
|
||||
});
|
||||
}
|
||||
}
|
||||
builder_monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
std::vector<size_t>* row_indices) {
|
||||
const auto& info = fmat.Info();
|
||||
auto& rnd = common::GlobalRandom();
|
||||
std::vector<size_t>& row_indices_local = *row_indices;
|
||||
size_t* p_row_indices = row_indices_local.data();
|
||||
unused_rows_.resize(info.num_row_);
|
||||
size_t* p_row_indices_used = row_indices->data();
|
||||
size_t* p_row_indices_unused = unused_rows_.data();
|
||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
||||
size_t j = 0;
|
||||
size_t used = 0, unused = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
|
||||
p_row_indices[j++] = i;
|
||||
p_row_indices_used[used++] = i;
|
||||
} else {
|
||||
p_row_indices_unused[unused++] = i;
|
||||
}
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
row_indices_local.resize(j);
|
||||
row_indices->resize(used);
|
||||
unused_rows_.resize(unused);
|
||||
#else
|
||||
const size_t nthread = this->nthread_;
|
||||
std::vector<size_t> row_offsets(nthread, 0);
|
||||
std::vector<size_t> row_offsets_used(nthread, 0);
|
||||
std::vector<size_t> row_offsets_unused(nthread, 0);
|
||||
/* using mt19937_64 gives a 2x speedup for subsampling */
|
||||
std::vector<std::mt19937> rnds(nthread);
|
||||
/* create engine for each thread */
|
||||
@ -725,27 +744,51 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<Gr
|
||||
rnds[tid].discard(discard_size * tid);
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) {
|
||||
p_row_indices[ibegin + row_offsets[tid]++] = i;
|
||||
p_row_indices_used[ibegin + row_offsets_used[tid]++] = i;
|
||||
} else {
|
||||
p_row_indices_unused[ibegin + row_offsets_unused[tid]++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp barrier
|
||||
|
||||
if (tid == 0ul) {
|
||||
size_t prefix_sum_used = row_offsets_used[0];
|
||||
for (size_t i = 1; i < nthread; ++i) {
|
||||
const size_t ibegin = i * discard_size;
|
||||
|
||||
for (size_t k = 0; k < row_offsets_used[i]; ++k) {
|
||||
p_row_indices_used[prefix_sum_used + k] = p_row_indices_used[ibegin + k];
|
||||
}
|
||||
|
||||
prefix_sum_used += row_offsets_used[i];
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
row_indices->resize(prefix_sum_used);
|
||||
}
|
||||
|
||||
if (nthread == 1ul || tid == 1ul) {
|
||||
size_t prefix_sum_unused = row_offsets_unused[0];
|
||||
for (size_t i = 1; i < nthread; ++i) {
|
||||
const size_t ibegin = i * discard_size;
|
||||
|
||||
for (size_t k = 0; k < row_offsets_unused[i]; ++k) {
|
||||
p_row_indices_unused[prefix_sum_unused + k] = p_row_indices_unused[ibegin + k];
|
||||
}
|
||||
|
||||
prefix_sum_unused += row_offsets_unused[i];
|
||||
}
|
||||
/* resize unused_rows_ to reduce memory */
|
||||
unused_rows_.resize(prefix_sum_unused);
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
/* discard global engine */
|
||||
rnd = rnds[nthread - 1];
|
||||
size_t prefix_sum = row_offsets[0];
|
||||
for (size_t i = 1; i < nthread; ++i) {
|
||||
const size_t ibegin = i * discard_size;
|
||||
|
||||
for (size_t k = 0; k < row_offsets[i]; ++k) {
|
||||
row_indices_local[prefix_sum + k] = row_indices_local[ibegin + k];
|
||||
}
|
||||
prefix_sum += row_offsets[i];
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
row_indices_local.resize(prefix_sum);
|
||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
@ -764,8 +807,6 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
{
|
||||
// initialize the row set
|
||||
row_set_collection_.Clear();
|
||||
// clear local prediction cache
|
||||
leaf_value_cache_.clear();
|
||||
// initialize histogram collection
|
||||
uint32_t nbins = gmat.cut.Ptrs().back();
|
||||
hist_.Init(nbins);
|
||||
@ -793,6 +834,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
<< "Only uniform sampling is supported, "
|
||||
<< "gradient-based sampling is only support by GPU Hist.";
|
||||
InitSampling(gpair, fmat, &row_indices);
|
||||
// We should check that the partitioning was done correctly
|
||||
// and each row of the dataset fell into exactly one of the categories
|
||||
CHECK_EQ(row_indices.size() + unused_rows_.size(), info.num_row_);
|
||||
} else {
|
||||
MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
bool* p_buff = buff.Get();
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2018 by Contributors
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \file updater_quantile_hist.h
|
||||
* \brief use quantized feature values to construct a tree
|
||||
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
||||
@ -11,13 +11,12 @@
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <queue>
|
||||
#include <iomanip>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
@ -409,6 +408,10 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
common::ColumnSampler column_sampler_;
|
||||
// the internal row sets
|
||||
RowSetCollection row_set_collection_;
|
||||
// tree rows that were not used for current training
|
||||
std::vector<size_t> unused_rows_;
|
||||
// feature vectors for subsampled prediction
|
||||
std::vector<RegTree::FVec> feat_vecs_;
|
||||
// the temp space for split
|
||||
std::vector<RowSetCollection::Split> row_split_tloc_;
|
||||
std::vector<SplitEntry> best_split_tloc_;
|
||||
@ -422,8 +425,6 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||
of InitNewNode() */
|
||||
uint32_t fid_least_bins_;
|
||||
/*! \brief local prediction cache; maps node id to leaf value */
|
||||
std::vector<float> leaf_value_cache_;
|
||||
|
||||
GHistBuilder<GradientSumT> hist_builder_;
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
@ -435,6 +436,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
// back pointers to tree and data matrix
|
||||
const RegTree* p_last_tree_;
|
||||
DMatrix const* const p_last_fmat_;
|
||||
DMatrix* p_last_fmat_mutable_;
|
||||
|
||||
using ExpandQueue =
|
||||
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
|
||||
|
||||
@ -183,7 +183,7 @@ TEST(CpuPredictor, InplacePredict) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
void TestUpdatePredictionCache(bool use_subsampling) {
|
||||
size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
|
||||
LearnerModelParam mparam;
|
||||
mparam.num_feature = kCols;
|
||||
@ -198,6 +198,9 @@ TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
std::map<std::string, std::string> cfg;
|
||||
cfg["tree_method"] = "hist";
|
||||
cfg["predictor"] = "cpu_predictor";
|
||||
if (use_subsampling) {
|
||||
cfg["subsample"] = "0.5";
|
||||
}
|
||||
Args args = {cfg.cbegin(), cfg.cend()};
|
||||
gbm->Configure(args);
|
||||
|
||||
@ -226,6 +229,11 @@ TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
TestUpdatePredictionCache(false);
|
||||
TestUpdatePredictionCache(true);
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, LesserFeatures) {
|
||||
TestPredictionWithLesserFeatures("cpu_predictor");
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018-2019 by Contributors
|
||||
* Copyright 2018-2021 by Contributors
|
||||
*/
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
@ -107,19 +107,45 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
const size_t nthreads = omp_get_num_threads();
|
||||
// save state of global rng engine
|
||||
auto initial_rnd = common::GlobalRandom();
|
||||
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
|
||||
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
|
||||
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
|
||||
const std::vector<size_t>& second,
|
||||
size_t nrows) {
|
||||
std::vector<size_t> arr_union(nrows);
|
||||
for (auto&& row_indice : first) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_indice : second) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_cnt : arr_union) {
|
||||
ASSERT_EQ(row_cnt, 1ul);
|
||||
}
|
||||
};
|
||||
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
|
||||
p_fmat->Info().num_row_);
|
||||
|
||||
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
|
||||
omp_set_num_threads(i_nthreads);
|
||||
// restore the initial state of the global rng engine
|
||||
common::GlobalRandom() = initial_rnd;
|
||||
this->unused_rows_ = unused_rows_cpy;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
|
||||
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
|
||||
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
|
||||
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
|
||||
}
|
||||
std::vector<size_t>& unused_row_indices = this->unused_rows_;
|
||||
ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
|
||||
for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
|
||||
ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
|
||||
}
|
||||
check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
|
||||
p_fmat->Info().num_row_);
|
||||
}
|
||||
omp_set_num_threads(nthreads);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user