Use view for SparsePage exclusively. (#6590)
This commit is contained in:
@@ -59,8 +59,9 @@ class BaseMaker: public TreeUpdater {
|
||||
-std::numeric_limits<bst_float>::max());
|
||||
// start accumulating statistics
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
|
||||
auto c = batch[fid];
|
||||
auto c = page[fid];
|
||||
if (c.size() != 0) {
|
||||
CHECK_LT(fid * 2, fminmax_.size());
|
||||
fminmax_[fid * 2 + 0] =
|
||||
@@ -249,8 +250,9 @@ class BaseMaker: public TreeUpdater {
|
||||
inline void CorrectNonDefaultPositionByBatch(
|
||||
const SparsePage &batch, const std::vector<bst_uint> &sorted_split_set,
|
||||
const RegTree &tree) {
|
||||
auto page = batch.GetView();
|
||||
for (size_t fid = 0; fid < batch.Size(); ++fid) {
|
||||
auto col = batch[fid];
|
||||
auto col = page[fid];
|
||||
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
|
||||
|
||||
if (it != sorted_split_set.end() && *it == fid) {
|
||||
@@ -308,10 +310,11 @@ class BaseMaker: public TreeUpdater {
|
||||
std::vector<unsigned> fsplits;
|
||||
this->GetSplitSet(nodes, tree, &fsplits);
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
auto col = page[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const bst_float fvalue = col[j].fvalue;
|
||||
|
||||
@@ -77,8 +77,9 @@ class ColMaker: public TreeUpdater {
|
||||
if (column_densities_.empty()) {
|
||||
std::vector<size_t> column_size(dmat->Info().num_col_);
|
||||
for (const auto &batch : dmat->GetBatches<SortedCSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
for (auto i = 0u; i < batch.Size(); i++) {
|
||||
column_size[i] += batch[i].size();
|
||||
column_size[i] += page[i].size();
|
||||
}
|
||||
}
|
||||
column_densities_.resize(column_size.size());
|
||||
@@ -447,13 +448,14 @@ class ColMaker: public TreeUpdater {
|
||||
#endif // defined(_OPENMP)
|
||||
{
|
||||
dmlc::OMPException omp_handler;
|
||||
auto page = batch.GetView();
|
||||
#pragma omp parallel for schedule(dynamic, batch_size)
|
||||
for (bst_omp_uint i = 0; i < num_features; ++i) {
|
||||
omp_handler.Run([&]() {
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
bst_feature_t const fid = feat_set[i];
|
||||
int32_t const tid = omp_get_thread_num();
|
||||
auto c = batch[fid];
|
||||
auto c = page[fid];
|
||||
const bool ind =
|
||||
c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
|
||||
if (colmaker_train_param_.NeedForwardSearch(
|
||||
@@ -562,8 +564,9 @@ class ColMaker: public TreeUpdater {
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
auto col = page[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
|
||||
@@ -338,6 +338,7 @@ class CQHistMaker: public HistMaker {
|
||||
thread_hist_.resize(omp_get_max_threads());
|
||||
// start accumulating statistics
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(fset.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
@@ -345,7 +346,7 @@ class CQHistMaker: public HistMaker {
|
||||
int fid = fset[i];
|
||||
int offset = feat2workindex_[fid];
|
||||
if (offset >= 0) {
|
||||
this->UpdateHistCol(gpair, batch[fid], info, tree,
|
||||
this->UpdateHistCol(gpair, page[fid], info, tree,
|
||||
fset, offset,
|
||||
&thread_hist_[omp_get_thread_num()]);
|
||||
}
|
||||
@@ -413,15 +414,15 @@ class CQHistMaker: public HistMaker {
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
|
||||
|
||||
auto page = batch.GetView();
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(work_set_.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
int fid = work_set_[i];
|
||||
int offset = feat2workindex_[fid];
|
||||
if (offset >= 0) {
|
||||
this->UpdateSketchCol(gpair, batch[fid], tree,
|
||||
this->UpdateSketchCol(gpair, page[fid], tree,
|
||||
work_set_size, offset,
|
||||
&thread_sketch_[omp_get_thread_num()]);
|
||||
}
|
||||
@@ -696,6 +697,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
|
||||
auto page = batch.GetView();
|
||||
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
|
||||
@@ -704,7 +706,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
|
||||
int fid = this->work_set_[i];
|
||||
int offset = this->feat2workindex_[fid];
|
||||
if (offset >= 0) {
|
||||
this->UpdateHistCol(gpair, batch[fid], info, tree,
|
||||
this->UpdateHistCol(gpair, page[fid], info, tree,
|
||||
fset, offset,
|
||||
&this->thread_hist_[omp_get_thread_num()]);
|
||||
}
|
||||
|
||||
@@ -69,11 +69,12 @@ class TreeRefresher: public TreeUpdater {
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// start accumulating statistics
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
auto page = batch.GetView();
|
||||
CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
|
||||
const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
SparsePage::Inst inst = batch[i];
|
||||
SparsePage::Inst inst = page[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
const auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = fvec_temp[tid];
|
||||
|
||||
Reference in New Issue
Block a user