diff --git a/include/xgboost/data.h b/include/xgboost/data.h index bcace4656..00228b145 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -529,6 +529,11 @@ class DMatrix { return Info().num_nonzero_ == Info().num_row_ * Info().num_col_; } + /*! \brief Whether the data is split column-wise. */ + bool IsColumnSplit() const { + return Info().data_split_mode == DataSplitMode::kCol; + } + /*! * \brief Load DMatrix from URI. * \param uri The URI of input. diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 3b4d42a8d..6e83c084e 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b if (!use_sorted) { HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced, HostSketchContainer::UseGroup(info), - m->Info().data_split_mode == DataSplitMode::kCol, n_threads); + m->IsColumnSplit(), n_threads); for (auto const& page : m->GetBatches()) { container.PushRowPage(page, info, hessian); } @@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b } else { SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced, HostSketchContainer::UseGroup(info), - m->Info().data_split_mode == DataSplitMode::kCol, n_threads}; + m->IsColumnSplit(), n_threads}; for (auto const& page : m->GetBatches()) { container.PushColPage(page, info, hessian); } diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 472227e38..ae0cfc4a4 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, SyncFeatureType(&h_ft); p_sketch.reset(new common::HostSketchContainer{ batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(), - proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()}); + proxy->IsColumnSplit(), ctx_.Threads()}); } HostAdapterDispatch(proxy, [&](auto const& batch) { proxy->Info().num_nonzero_ = batch_nnz[i]; diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 2f578fae7..7efaa915a 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -584,7 +584,7 @@ class CPUPredictor : public Predictor { void PredictDMatrix(DMatrix *p_fmat, std::vector *out_preds, gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { - if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) { + if (p_fmat->IsColumnSplit()) { ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end); helper.PredictDMatrix(p_fmat, out_preds); return; diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index f3ed27a88..df5a85633 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -29,6 +29,7 @@ class HistogramBuilder { size_t n_batches_{0}; // Whether XGBoost is running in distributed environment. bool is_distributed_{false}; + bool is_col_split_{false}; public: /** @@ -40,7 +41,7 @@ class HistogramBuilder { * of using global rabit variable. */ void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches, - bool is_distributed) { + bool is_distributed, bool is_col_split) { CHECK_GE(n_threads, 1); n_threads_ = n_threads; n_batches_ = n_batches; @@ -50,6 +51,7 @@ class HistogramBuilder { buffer_.Init(total_bins); builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; + is_col_split_ = is_col_split; // Workaround s390x gcc 7.5.0 auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce; } @@ -130,7 +132,7 @@ class HistogramBuilder { return; } - if (is_distributed_) { + if (is_distributed_ && !is_col_split_) { this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, starting_index, sync_count); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 4852e325f..da5a7cf88 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -76,7 +76,7 @@ class GloablApproxBuilder { } histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_, - collective::IsDistributed()); + collective::IsDistributed(), p_fmat->IsColumnSplit()); monitor_->Stop(__func__); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index ad2e57aa9..21732ce56 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, ++page_id; } histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, - collective::IsDistributed()); + collective::IsDistributed(), fmat->IsColumnSplit()); auto m_gpair = linalg::MakeTensorView(*gpair, {gpair->size(), static_cast(1)}, ctx_->gpu_id); diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 1e37f1cd4..8462fa7d5 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -48,7 +48,7 @@ void TestAddHistRows(bool is_distributed) { HistogramBuilder histogram_builder; histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, - is_distributed); + is_distributed, false); histogram_builder.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -86,7 +86,7 @@ void TestSyncHist(bool is_distributed) { HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); - histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); + histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false); common::RowSetCollection row_set_collection_; { @@ -226,11 +226,14 @@ TEST(CPUHistogram, SyncHist) { TestSyncHist(false); } -void TestBuildHistogram(bool is_distributed, bool force_read_by_column) { +void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) { size_t constexpr kNRows = 8, kNCols = 16; int32_t constexpr kMaxBins = 4; - auto p_fmat = - RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); + auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); + if (is_col_split) { + p_fmat = std::shared_ptr{ + p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())}; + } auto const &gmat = *(p_fmat->GetBatches(BatchParam{kMaxBins, 0.5}).begin()); uint32_t total_bins = gmat.cut.Ptrs().back(); @@ -241,7 +244,8 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) { bst_node_t nid = 0; HistogramBuilder histogram; - histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); + histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, + is_col_split); RegTree tree; @@ -284,11 +288,16 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) { } TEST(CPUHistogram, BuildHist) { - TestBuildHistogram(true, false); - TestBuildHistogram(false, false); - TestBuildHistogram(true, true); - TestBuildHistogram(false, true); + TestBuildHistogram(true, false, false); + TestBuildHistogram(false, false, false); + TestBuildHistogram(true, true, false); + TestBuildHistogram(false, true, false); +} +TEST(CPUHistogram, BuildHistColSplit) { + auto constexpr kWorkers = 4; + RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true); + RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true); } namespace { @@ -340,7 +349,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { HistogramBuilder cat_hist; for (auto const &gidx : cat_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); + cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false); cat_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, gpair.HostVector(), force_read_by_column); @@ -354,7 +363,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { HistogramBuilder onehot_hist; for (auto const &gidx : encode_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); + onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false); onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, gpair.HostVector(), force_read_by_column); @@ -419,7 +428,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 256}; - multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false); + multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false); size_t page_idx{0}; for (auto const &page : m->GetBatches(batch_param)) { @@ -440,7 +449,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo common::RowSetCollection row_set_collection; InitRowPartitionForTest(&row_set_collection, n_samples); - single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false); + single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false); SparsePage concat; std::vector hess(m->Info().num_row_, 1.0f); for (auto const& page : m->GetBatches()) {