Refactor DMatrix to return batches of different page types (#4686)

* Use explicit template parameter for specifying page type.
This commit is contained in:
Rong Ou
2019-08-03 12:10:34 -07:00
committed by Jiaming Yuan
parent e930a8e54f
commit 6edddd7966
41 changed files with 477 additions and 470 deletions

View File

@@ -45,7 +45,7 @@ class BaseMaker: public TreeUpdater {
std::fill(fminmax_.begin(), fminmax_.end(),
-std::numeric_limits<bst_float>::max());
// start accumulating statistics
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
auto c = batch[fid];
if (c.size() != 0) {
@@ -302,7 +302,7 @@ class BaseMaker: public TreeUpdater {
const RegTree &tree) {
std::vector<unsigned> fsplits;
this->GetSplitSet(nodes, tree, &fsplits);
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.size());

View File

@@ -637,7 +637,7 @@ class ColMaker: public TreeUpdater {
DMatrix *p_fmat,
RegTree *p_tree) {
auto feat_set = column_sampler_.GetFeatureSet(depth);
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
}
// after this each thread's stemp will get the best candidates, aggregate results
@@ -716,7 +716,7 @@ class ColMaker: public TreeUpdater {
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.size());
@@ -846,7 +846,7 @@ class DistColMaker : public ColMaker {
boolmap_[j] = 0;
}
}
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.size());

View File

@@ -734,7 +734,7 @@ class GPUMaker : public TreeUpdater {
fId->reserve(n_cols_ * n_rows_);
// in case you end up with a DMatrix having no column access
// then make sure to enable that before copying the data!
- for (const auto& batch : dmat->GetSortedColumnBatches()) {
+ for (const auto& batch : dmat->GetBatches<SortedCSCPage>()) {
for (int i = 0; i < batch.Size(); i++) {
auto col = batch[i];
for (const Entry& e : col) {

View File

@@ -1382,7 +1382,7 @@ class GPUHistMakerSpecialised {
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(shards_);
- for (const auto &batch : dmat->GetRowBatches()) {
+ for (const auto &batch : dmat->GetBatches<SparsePage>()) {
hist_builder_row_state.BeginBatch(batch);
dh::ExecuteIndexShards(

View File

@@ -351,7 +351,7 @@ class CQHistMaker: public HistMaker {
auto lazy_get_hist = [&]() {
thread_hist_.resize(omp_get_max_threads());
// start accumulating statistics
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(fset.size());
#pragma omp parallel for schedule(dynamic, 1)
@@ -425,7 +425,7 @@ class CQHistMaker: public HistMaker {
work_set_.resize(std::unique(work_set_.begin(), work_set_.end()) - work_set_.begin());
// start accumulating statistics
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
@@ -707,7 +707,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
// start accumulating statistics
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);

View File

@@ -56,7 +56,7 @@ class TreeRefresher: public TreeUpdater {
auto lazy_get_stats = [&]() {
const MetaInfo &info = p_fmat->Info();
// start accumulating statistics
- for (const auto &batch : p_fmat->GetRowBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)

View File

@@ -135,7 +135,7 @@ class SketchMaker: public BaseMaker {
// number of rows in
const size_t nrows = p_fmat->Info().num_row_;
// start accumulating statistics
- for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
+ for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(dynamic, 1)