Support column split in histogram builder (#8811)

This commit is contained in:
Rong Ou 2023-02-17 06:37:01 -08:00 committed by GitHub
parent 40fd3d6d5f
commit a65ad0bd9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 38 additions and 22 deletions

View File

@@ -529,6 +529,11 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_; return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
} }
/*! \brief Returns true iff this DMatrix's data is partitioned column-wise across workers. */
bool IsColumnSplit() const { return Info().data_split_mode == DataSplitMode::kCol; }
/*! /*!
* \brief Load DMatrix from URI. * \brief Load DMatrix from URI.
* \param uri The URI of input. * \param uri The URI of input.

View File

@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) { if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced, HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads); m->IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) { for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian); container.PushRowPage(page, info, hessian);
} }
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
} else { } else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced, SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads}; m->IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) { for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian); container.PushColPage(page, info, hessian);
} }

View File

@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft); SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{ p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(), batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()}); proxy->IsColumnSplit(), ctx_.Threads()});
} }
HostAdapterDispatch(proxy, [&](auto const& batch) { HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i]; proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@@ -584,7 +584,7 @@ class CPUPredictor : public Predictor {
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds, void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) { if (p_fmat->IsColumnSplit()) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end); ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds); helper.PredictDMatrix(p_fmat, out_preds);
return; return;

View File

@@ -29,6 +29,7 @@ class HistogramBuilder {
size_t n_batches_{0}; size_t n_batches_{0};
// Whether XGBoost is running in distributed environment. // Whether XGBoost is running in distributed environment.
bool is_distributed_{false}; bool is_distributed_{false};
bool is_col_split_{false};
public: public:
/** /**
@@ -40,7 +41,7 @@ class HistogramBuilder {
* of using global rabit variable. * of using global rabit variable.
*/ */
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches, void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
bool is_distributed) { bool is_distributed, bool is_col_split) {
CHECK_GE(n_threads, 1); CHECK_GE(n_threads, 1);
n_threads_ = n_threads; n_threads_ = n_threads;
n_batches_ = n_batches; n_batches_ = n_batches;
@@ -50,6 +51,7 @@ class HistogramBuilder {
buffer_.Init(total_bins); buffer_.Init(total_bins);
builder_ = common::GHistBuilder(total_bins); builder_ = common::GHistBuilder(total_bins);
is_distributed_ = is_distributed; is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
// Workaround s390x gcc 7.5.0 // Workaround s390x gcc 7.5.0
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce; auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
} }
@@ -130,7 +132,7 @@ class HistogramBuilder {
return; return;
} }
if (is_distributed_) { if (is_distributed_ && !is_col_split_) {
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, nodes_for_subtraction_trick,
starting_index, sync_count); starting_index, sync_count);

View File

@@ -76,7 +76,7 @@ class GloablApproxBuilder {
} }
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_, histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed()); collective::IsDistributed(), p_fmat->IsColumnSplit());
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }

View File

@@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
++page_id; ++page_id;
} }
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed()); collective::IsDistributed(), fmat->IsColumnSplit());
auto m_gpair = auto m_gpair =
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id); linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);

View File

@@ -48,7 +48,7 @@ void TestAddHistRows(bool is_distributed) {
HistogramBuilder<CPUExpandEntry> histogram_builder; HistogramBuilder<CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
is_distributed); is_distributed, false);
histogram_builder.AddHistRows(&starting_index, &sync_count, histogram_builder.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree); nodes_for_subtraction_trick_, &tree);
@@ -86,7 +86,7 @@ void TestSyncHist(bool is_distributed) {
HistogramBuilder<CPUExpandEntry> histogram; HistogramBuilder<CPUExpandEntry> histogram;
uint32_t total_bins = gmat.cut.Ptrs().back(); uint32_t total_bins = gmat.cut.Ptrs().back();
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false);
common::RowSetCollection row_set_collection_; common::RowSetCollection row_set_collection_;
{ {
@@ -226,11 +226,14 @@ TEST(CPUHistogram, SyncHist) {
TestSyncHist(false); TestSyncHist(false);
} }
void TestBuildHistogram(bool is_distributed, bool force_read_by_column) { void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
size_t constexpr kNRows = 8, kNCols = 16; size_t constexpr kNRows = 8, kNCols = 16;
int32_t constexpr kMaxBins = 4; int32_t constexpr kMaxBins = 4;
auto p_fmat = auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); if (is_col_split) {
p_fmat = std::shared_ptr<DMatrix>{
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
}
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin()); auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
uint32_t total_bins = gmat.cut.Ptrs().back(); uint32_t total_bins = gmat.cut.Ptrs().back();
@@ -241,7 +244,8 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
bst_node_t nid = 0; bst_node_t nid = 0;
HistogramBuilder<CPUExpandEntry> histogram; HistogramBuilder<CPUExpandEntry> histogram;
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed,
is_col_split);
RegTree tree; RegTree tree;
@@ -284,11 +288,16 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
} }
TEST(CPUHistogram, BuildHist) { TEST(CPUHistogram, BuildHist) {
TestBuildHistogram(true, false); TestBuildHistogram(true, false, false);
TestBuildHistogram(false, false); TestBuildHistogram(false, false, false);
TestBuildHistogram(true, true); TestBuildHistogram(true, true, false);
TestBuildHistogram(false, true); TestBuildHistogram(false, true, false);
}
TEST(CPUHistogram, BuildHistColSplit) {
auto constexpr kWorkers = 4;
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true);
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true);
} }
namespace { namespace {
@@ -340,7 +349,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
HistogramBuilder<CPUExpandEntry> cat_hist; HistogramBuilder<CPUExpandEntry> cat_hist;
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) { for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
auto total_bins = gidx.cut.TotalBins(); auto total_bins = gidx.cut.TotalBins();
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
cat_hist.BuildHist(0, gidx, &tree, row_set_collection, cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
nodes_for_explicit_hist_build, {}, gpair.HostVector(), nodes_for_explicit_hist_build, {}, gpair.HostVector(),
force_read_by_column); force_read_by_column);
@@ -354,7 +363,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
HistogramBuilder<CPUExpandEntry> onehot_hist; HistogramBuilder<CPUExpandEntry> onehot_hist;
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) { for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
auto total_bins = gidx.cut.TotalBins(); auto total_bins = gidx.cut.TotalBins();
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
gpair.HostVector(), gpair.HostVector(),
force_read_by_column); force_read_by_column);
@@ -419,7 +428,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
256}; 256};
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false); multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false);
size_t page_idx{0}; size_t page_idx{0};
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) { for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
@ -440,7 +449,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
InitRowPartitionForTest(&row_set_collection, n_samples); InitRowPartitionForTest(&row_set_collection, n_samples);
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false); single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false);
SparsePage concat; SparsePage concat;
std::vector<float> hess(m->Info().num_row_, 1.0f); std::vector<float> hess(m->Info().num_row_, 1.0f);
for (auto const& page : m->GetBatches<SparsePage>()) { for (auto const& page : m->GetBatches<SparsePage>()) {