Use view for SparsePage exclusively. (#6590)

This commit is contained in:
Jiaming Yuan
2021-01-11 18:04:55 +08:00
committed by GitHub
parent 78f2cd83d7
commit f2f7dd87b8
23 changed files with 151 additions and 113 deletions

View File

@@ -59,8 +59,9 @@ class BaseMaker: public TreeUpdater {
-std::numeric_limits<bst_float>::max());
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
auto page = batch.GetView();
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
auto c = batch[fid];
auto c = page[fid];
if (c.size() != 0) {
CHECK_LT(fid * 2, fminmax_.size());
fminmax_[fid * 2 + 0] =
@@ -249,8 +250,9 @@ class BaseMaker: public TreeUpdater {
inline void CorrectNonDefaultPositionByBatch(
const SparsePage &batch, const std::vector<bst_uint> &sorted_split_set,
const RegTree &tree) {
auto page = batch.GetView();
for (size_t fid = 0; fid < batch.Size(); ++fid) {
auto col = batch[fid];
auto col = page[fid];
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
if (it != sorted_split_set.end() && *it == fid) {
@@ -308,10 +310,11 @@ class BaseMaker: public TreeUpdater {
std::vector<unsigned> fsplits;
this->GetSplitSet(nodes, tree, &fsplits);
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
auto page = batch.GetView();
for (auto fid : fsplits) {
auto col = batch[fid];
auto col = page[fid];
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const bst_float fvalue = col[j].fvalue;

View File

@@ -77,8 +77,9 @@ class ColMaker: public TreeUpdater {
if (column_densities_.empty()) {
std::vector<size_t> column_size(dmat->Info().num_col_);
for (const auto &batch : dmat->GetBatches<SortedCSCPage>()) {
auto page = batch.GetView();
for (auto i = 0u; i < batch.Size(); i++) {
column_size[i] += batch[i].size();
column_size[i] += page[i].size();
}
}
column_densities_.resize(column_size.size());
@@ -447,13 +448,14 @@ class ColMaker: public TreeUpdater {
#endif // defined(_OPENMP)
{
dmlc::OMPException omp_handler;
auto page = batch.GetView();
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < num_features; ++i) {
omp_handler.Run([&]() {
auto evaluator = tree_evaluator_.GetEvaluator();
bst_feature_t const fid = feat_set[i];
int32_t const tid = omp_get_thread_num();
auto c = batch[fid];
auto c = page[fid];
const bool ind =
c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
if (colmaker_train_param_.NeedForwardSearch(
@@ -562,8 +564,9 @@ class ColMaker: public TreeUpdater {
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
auto page = batch.GetView();
for (auto fid : fsplits) {
auto col = batch[fid];
auto col = page[fid];
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {

View File

@@ -338,6 +338,7 @@ class CQHistMaker: public HistMaker {
thread_hist_.resize(omp_get_max_threads());
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
auto page = batch.GetView();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(fset.size());
#pragma omp parallel for schedule(dynamic, 1)
@@ -345,7 +346,7 @@ class CQHistMaker: public HistMaker {
int fid = fset[i];
int offset = feat2workindex_[fid];
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[fid], info, tree,
this->UpdateHistCol(gpair, page[fid], info, tree,
fset, offset,
&thread_hist_[omp_get_thread_num()]);
}
@@ -413,15 +414,15 @@ class CQHistMaker: public HistMaker {
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
auto page = batch.GetView();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(work_set_.size());
#pragma omp parallel for schedule(dynamic, 1)
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int fid = work_set_[i];
int offset = feat2workindex_[fid];
if (offset >= 0) {
this->UpdateSketchCol(gpair, batch[fid], tree,
this->UpdateSketchCol(gpair, page[fid], tree,
work_set_size, offset,
&thread_sketch_[omp_get_thread_num()]);
}
@@ -696,6 +697,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
auto page = batch.GetView();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
@@ -704,7 +706,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
int fid = this->work_set_[i];
int offset = this->feat2workindex_[fid];
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[fid], info, tree,
this->UpdateHistCol(gpair, page[fid], info, tree,
fset, offset,
&this->thread_hist_[omp_get_thread_num()]);
}

View File

@@ -69,11 +69,12 @@ class TreeRefresher: public TreeUpdater {
const MetaInfo &info = p_fmat->Info();
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();
CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
SparsePage::Inst inst = batch[i];
SparsePage::Inst inst = page[i];
const int tid = omp_get_thread_num();
const auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
RegTree::FVec &feats = fvec_temp[tid];