Support column split in histogram builder (#8811)
This commit is contained in:
@@ -29,6 +29,7 @@ class HistogramBuilder {
|
||||
size_t n_batches_{0};
|
||||
// Whether XGBoost is running in distributed environment.
|
||||
bool is_distributed_{false};
|
||||
bool is_col_split_{false};
|
||||
|
||||
public:
|
||||
/**
|
||||
@@ -40,7 +41,7 @@ class HistogramBuilder {
|
||||
* of using global rabit variable.
|
||||
*/
|
||||
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
|
||||
bool is_distributed) {
|
||||
bool is_distributed, bool is_col_split) {
|
||||
CHECK_GE(n_threads, 1);
|
||||
n_threads_ = n_threads;
|
||||
n_batches_ = n_batches;
|
||||
@@ -50,6 +51,7 @@ class HistogramBuilder {
|
||||
buffer_.Init(total_bins);
|
||||
builder_ = common::GHistBuilder(total_bins);
|
||||
is_distributed_ = is_distributed;
|
||||
is_col_split_ = is_col_split;
|
||||
// Workaround s390x gcc 7.5.0
|
||||
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
|
||||
}
|
||||
@@ -130,7 +132,7 @@ class HistogramBuilder {
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_distributed_) {
|
||||
if (is_distributed_ && !is_col_split_) {
|
||||
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
|
||||
nodes_for_subtraction_trick,
|
||||
starting_index, sync_count);
|
||||
|
||||
@@ -76,7 +76,7 @@ class GloablApproxBuilder {
|
||||
}
|
||||
|
||||
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
||||
collective::IsDistributed());
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||
++page_id;
|
||||
}
|
||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed());
|
||||
collective::IsDistributed(), fmat->IsColumnSplit());
|
||||
|
||||
auto m_gpair =
|
||||
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
||||
|
||||
Reference in New Issue
Block a user