Support column split in histogram builder (#8811)

This commit is contained in:
Rong Ou
2023-02-17 06:37:01 -08:00
committed by GitHub
parent 40fd3d6d5f
commit a65ad0bd9c
8 changed files with 38 additions and 22 deletions

View File

@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads);
m->IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads};
m->IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}

View File

@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
proxy->IsColumnSplit(), ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@@ -584,7 +584,7 @@ class CPUPredictor : public Predictor {
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) {
if (p_fmat->IsColumnSplit()) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
return;

View File

@@ -29,6 +29,7 @@ class HistogramBuilder {
size_t n_batches_{0};
// Whether XGBoost is running in distributed environment.
bool is_distributed_{false};
bool is_col_split_{false};
public:
/**
@@ -40,7 +41,7 @@ class HistogramBuilder {
* of using global rabit variable.
*/
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
bool is_distributed) {
bool is_distributed, bool is_col_split) {
CHECK_GE(n_threads, 1);
n_threads_ = n_threads;
n_batches_ = n_batches;
@@ -50,6 +51,7 @@ class HistogramBuilder {
buffer_.Init(total_bins);
builder_ = common::GHistBuilder(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
// Workaround s390x gcc 7.5.0
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
}
@@ -130,7 +132,7 @@ class HistogramBuilder {
return;
}
if (is_distributed_) {
if (is_distributed_ && !is_col_split_) {
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick,
starting_index, sync_count);

View File

@@ -76,7 +76,7 @@ class GloablApproxBuilder {
}
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed());
collective::IsDistributed(), p_fmat->IsColumnSplit());
monitor_->Stop(__func__);
}

View File

@@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed());
collective::IsDistributed(), fmat->IsColumnSplit());
auto m_gpair =
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);