Refactor DMatrix to return batches of different page types (#4686)
* Use explicit template parameter for specifying page type.
This commit is contained in:
@@ -45,7 +45,7 @@ class BaseMaker: public TreeUpdater {
|
||||
std::fill(fminmax_.begin(), fminmax_.end(),
|
||||
-std::numeric_limits<bst_float>::max());
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
|
||||
auto c = batch[fid];
|
||||
if (c.size() != 0) {
|
||||
@@ -302,7 +302,7 @@ class BaseMaker: public TreeUpdater {
|
||||
const RegTree &tree) {
|
||||
std::vector<unsigned> fsplits;
|
||||
this->GetSplitSet(nodes, tree, &fsplits);
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
|
||||
@@ -637,7 +637,7 @@ class ColMaker: public TreeUpdater {
|
||||
DMatrix *p_fmat,
|
||||
RegTree *p_tree) {
|
||||
auto feat_set = column_sampler_.GetFeatureSet(depth);
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
|
||||
}
|
||||
// after this each thread's stemp will get the best candidates, aggregate results
|
||||
@@ -716,7 +716,7 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
@@ -846,7 +846,7 @@ class DistColMaker : public ColMaker {
|
||||
boolmap_[j] = 0;
|
||||
}
|
||||
}
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
|
||||
@@ -734,7 +734,7 @@ class GPUMaker : public TreeUpdater {
|
||||
fId->reserve(n_cols_ * n_rows_);
|
||||
// in case you end up with a DMatrix having no column access
|
||||
// then make sure to enable that before copying the data!
|
||||
-for (const auto& batch : dmat->GetSortedColumnBatches()) {
|
||||
+for (const auto& batch : dmat->GetBatches<SortedCSCPage>()) {
|
||||
for (int i = 0; i < batch.Size(); i++) {
|
||||
auto col = batch[i];
|
||||
for (const Entry& e : col) {
|
||||
|
||||
@@ -1382,7 +1382,7 @@ class GPUHistMakerSpecialised {
|
||||
|
||||
monitor_.StartCuda("BinningCompression");
|
||||
DeviceHistogramBuilderState hist_builder_row_state(shards_);
|
||||
-for (const auto &batch : dmat->GetRowBatches()) {
|
||||
+for (const auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
hist_builder_row_state.BeginBatch(batch);
|
||||
|
||||
dh::ExecuteIndexShards(
|
||||
|
||||
@@ -351,7 +351,7 @@ class CQHistMaker: public HistMaker {
|
||||
auto lazy_get_hist = [&]() {
|
||||
thread_hist_.resize(omp_get_max_threads());
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(fset.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
@@ -425,7 +425,7 @@ class CQHistMaker: public HistMaker {
|
||||
work_set_.resize(std::unique(work_set_.begin(), work_set_.end()) - work_set_.begin());
|
||||
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
|
||||
|
||||
@@ -707,7 +707,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
|
||||
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
|
||||
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
auto lazy_get_stats = [&]() {
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
|
||||
const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
|
||||
@@ -135,7 +135,7 @@ class SketchMaker: public BaseMaker {
|
||||
// number of rows in the DMatrix
|
||||
const size_t nrows = p_fmat->Info().num_row_;
|
||||
// start accumulating statistics
|
||||
-for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
+for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
|
||||
Reference in New Issue
Block a user