Support column split in histogram builder (#8811)
parent 40fd3d6d5f
commit a65ad0bd9c
@@ -529,6 +529,11 @@ class DMatrix {
     return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
   }
 
+  /*! \brief Whether the data is split column-wise. */
+  bool IsColumnSplit() const {
+    return Info().data_split_mode == DataSplitMode::kCol;
+  }
+
   /*!
    * \brief Load DMatrix from URI.
    * \param uri The URI of input.
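The new accessor only names an existing check, but it keeps call sites from repeating the enum comparison. A minimal, self-contained sketch of the pattern; FakeDMatrix and the enum values here are illustrative stand-ins, not XGBoost's real definitions:

#include <cstdint>
#include <iostream>

// Illustrative stand-ins for the real DataSplitMode / MetaInfo / DMatrix types.
enum class DataSplitMode : std::int32_t { kRow = 0, kCol = 1 };

struct MetaInfo {
  DataSplitMode data_split_mode{DataSplitMode::kRow};
};

class FakeDMatrix {
 public:
  explicit FakeDMatrix(DataSplitMode mode) { info_.data_split_mode = mode; }
  MetaInfo const& Info() const { return info_; }
  // The convenience predicate added by the commit: callers no longer spell out
  // the enum comparison themselves.
  bool IsColumnSplit() const { return Info().data_split_mode == DataSplitMode::kCol; }

 private:
  MetaInfo info_;
};

int main() {
  FakeDMatrix row_split{DataSplitMode::kRow};
  FakeDMatrix col_split{DataSplitMode::kCol};
  std::cout << row_split.IsColumnSplit() << " " << col_split.IsColumnSplit() << "\n";  // prints "0 1"
}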
@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   if (!use_sorted) {
     HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
                                   HostSketchContainer::UseGroup(info),
-                                  m->Info().data_split_mode == DataSplitMode::kCol, n_threads);
+                                  m->IsColumnSplit(), n_threads);
     for (auto const& page : m->GetBatches<SparsePage>()) {
       container.PushRowPage(page, info, hessian);
     }
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   } else {
     SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
                                     HostSketchContainer::UseGroup(info),
-                                    m->Info().data_split_mode == DataSplitMode::kCol, n_threads};
+                                    m->IsColumnSplit(), n_threads};
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
       container.PushColPage(page, info, hessian);
     }
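Both sketch containers now receive the data layout as a boolean taken from the DMatrix instead of re-deriving it. The toy container below sketches why a quantile sketcher might care, under the assumption that a column-split worker holds every row for its local features and therefore already has complete per-feature sketches; ToySketchContainer and MakeCuts are hypothetical names, not XGBoost's API.

#include <cstdint>
#include <iostream>

// ToySketchContainer is a hypothetical stand-in; only the constructor wiring
// mirrors the call sites in the diff.
class ToySketchContainer {
 public:
  ToySketchContainer(std::int32_t max_bins, bool use_group, bool col_split, std::int32_t n_threads)
      : max_bins_{max_bins}, use_group_{use_group}, col_split_{col_split}, n_threads_{n_threads} {}

  void MakeCuts() const {
    std::cout << "max_bins=" << max_bins_ << " use_group=" << use_group_
              << " threads=" << n_threads_ << "\n";
    if (col_split_) {
      // Assumption: a column-split worker holds all rows of its local features,
      // so its per-feature sketches are already complete.
      std::cout << "column split: keep local per-feature sketches as-is\n";
    } else {
      std::cout << "row split / single node: merge per-feature sketches across workers\n";
    }
  }

 private:
  std::int32_t max_bins_;
  bool use_group_;
  bool col_split_;
  std::int32_t n_threads_;
};

int main() {
  bool const is_column_split = true;  // would come from m->IsColumnSplit() at the real call sites
  ToySketchContainer container{256, /*use_group=*/false, is_column_split, /*n_threads=*/4};
  container.MakeCuts();
}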
@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
     SyncFeatureType(&h_ft);
     p_sketch.reset(new common::HostSketchContainer{
         batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
-        proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
+        proxy->IsColumnSplit(), ctx_.Threads()});
   }
   HostAdapterDispatch(proxy, [&](auto const& batch) {
     proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -584,7 +584,7 @@ class CPUPredictor : public Predictor {
 
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
-    if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) {
+    if (p_fmat->IsColumnSplit()) {
       ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
       helper.PredictDMatrix(p_fmat, out_preds);
       return;
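The predictor change is the same substitution applied to an existing branch: column-split matrices are routed to a dedicated helper and the function returns early. A stripped-down sketch of that dispatch shape, with PredictColumnSplit and PredictLocal as hypothetical stand-ins for the real ColumnSplitHelper and the regular prediction path:

#include <iostream>
#include <vector>

struct ToyDMatrix {
  bool column_split{false};
  bool IsColumnSplit() const { return column_split; }
};

// Hypothetical stand-ins for the two prediction paths; each just reports the
// strategy it represents.
void PredictColumnSplit(ToyDMatrix const&, std::vector<float>* out_preds) {
  std::cout << "column-split path: workers cooperate, each evaluating the features it owns\n";
  out_preds->assign(1, 0.0f);
}

void PredictLocal(ToyDMatrix const&, std::vector<float>* out_preds) {
  std::cout << "regular path: every worker holds full rows and predicts locally\n";
  out_preds->assign(1, 0.0f);
}

// Mirrors the early-return dispatch in the diff.
void PredictDMatrix(ToyDMatrix const& m, std::vector<float>* out_preds) {
  if (m.IsColumnSplit()) {
    PredictColumnSplit(m, out_preds);
    return;
  }
  PredictLocal(m, out_preds);
}

int main() {
  std::vector<float> preds;
  PredictDMatrix(ToyDMatrix{true}, &preds);
  PredictDMatrix(ToyDMatrix{false}, &preds);
}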
@@ -29,6 +29,7 @@ class HistogramBuilder {
   size_t n_batches_{0};
   // Whether XGBoost is running in distributed environment.
   bool is_distributed_{false};
+  bool is_col_split_{false};
 
  public:
  /**
@@ -40,7 +41,7 @@ class HistogramBuilder {
   * of using global rabit variable.
   */
  void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
-            bool is_distributed) {
+            bool is_distributed, bool is_col_split) {
    CHECK_GE(n_threads, 1);
    n_threads_ = n_threads;
    n_batches_ = n_batches;
@@ -50,6 +51,7 @@ class HistogramBuilder {
    buffer_.Init(total_bins);
    builder_ = common::GHistBuilder(total_bins);
    is_distributed_ = is_distributed;
+   is_col_split_ = is_col_split;
    // Workaround s390x gcc 7.5.0
    auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
  }
@@ -130,7 +132,7 @@ class HistogramBuilder {
      return;
    }
 
-   if (is_distributed_) {
+   if (is_distributed_ && !is_col_split_) {
      this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
                                     nodes_for_subtraction_trick,
                                     starting_index, sync_count);
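This guard is the substantive change of the commit: histogram synchronization was previously gated only on running distributed, and is now also skipped when the data is split by column. The sketch below shows that control flow under the assumption stated in the comments, namely that a column-split worker computes a complete histogram for the features it owns, so an allreduce of per-bin sums across workers would be both unnecessary and wrong for disjoint feature sets. The class and the stubbed SyncHistogramDistributed are illustrative, not the real builder.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Minimal sketch of the decision added in this commit. The histogram payload
// and the actual allreduce are stubbed out; only the control flow matters.
class ToyHistogramBuilder {
 public:
  void Reset(std::uint32_t total_bins, bool is_distributed, bool is_col_split) {
    hist_.assign(total_bins, 0.0);
    is_distributed_ = is_distributed;
    is_col_split_ = is_col_split;
  }

  void BuildHist(std::vector<double> const& local_contrib) {
    for (std::size_t i = 0; i < hist_.size() && i < local_contrib.size(); ++i) {
      hist_[i] += local_contrib[i];
    }
    // Row split: every worker holds a subset of rows, so per-bin partial sums
    // must be allreduced. Column split: each worker holds all rows for its own
    // features, its local histogram is already complete, and summing across
    // workers would mix histograms of disjoint feature sets.
    if (is_distributed_ && !is_col_split_) {
      SyncHistogramDistributed();
    }
  }

 private:
  void SyncHistogramDistributed() {
    std::cout << "allreduce " << hist_.size() << " bins across workers\n";
  }

  std::vector<double> hist_;
  bool is_distributed_{false};
  bool is_col_split_{false};
};

int main() {
  ToyHistogramBuilder builder;
  builder.Reset(4, /*is_distributed=*/true, /*is_col_split=*/true);
  builder.BuildHist({1.0, 2.0, 3.0, 4.0});  // no sync: column split
  builder.Reset(4, /*is_distributed=*/true, /*is_col_split=*/false);
  builder.BuildHist({1.0, 2.0, 3.0, 4.0});  // prints the allreduce message
}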
@@ -76,7 +76,7 @@ class GloablApproxBuilder {
    }
 
    histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
-                            collective::IsDistributed());
+                            collective::IsDistributed(), p_fmat->IsColumnSplit());
    monitor_->Stop(__func__);
  }
 
@@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
    ++page_id;
  }
  histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
-                           collective::IsDistributed());
+                           collective::IsDistributed(), fmat->IsColumnSplit());
 
  auto m_gpair =
      linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
@@ -48,7 +48,7 @@ void TestAddHistRows(bool is_distributed) {
 
  HistogramBuilder<CPUExpandEntry> histogram_builder;
  histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
-                         is_distributed);
+                         is_distributed, false);
  histogram_builder.AddHistRows(&starting_index, &sync_count,
                                nodes_for_explicit_hist_build_,
                                nodes_for_subtraction_trick_, &tree);
@@ -86,7 +86,7 @@ void TestSyncHist(bool is_distributed) {
 
  HistogramBuilder<CPUExpandEntry> histogram;
  uint32_t total_bins = gmat.cut.Ptrs().back();
-  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
+  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false);
 
  common::RowSetCollection row_set_collection_;
  {
@@ -226,11 +226,14 @@ TEST(CPUHistogram, SyncHist) {
  TestSyncHist(false);
 }
 
-void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
+void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
  size_t constexpr kNRows = 8, kNCols = 16;
  int32_t constexpr kMaxBins = 4;
-  auto p_fmat =
-      RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
+  auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
+  if (is_col_split) {
+    p_fmat = std::shared_ptr<DMatrix>{
+        p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  }
  auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
  uint32_t total_bins = gmat.cut.Ptrs().back();
 
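The updated test builds one DMatrix per worker by slicing columns with SliceCol(world size, rank). The exact partitioning policy is not visible in the diff; the sketch below assumes a contiguous, roughly even block split just to show the arithmetic for the test's kNCols = 16 and four workers.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

// Illustrative only: one plausible way to carve n_cols columns into contiguous
// per-rank blocks; the actual policy of DMatrix::SliceCol is not shown in the diff.
std::pair<std::int32_t, std::int32_t> ColumnRange(std::int32_t n_cols, std::int32_t world_size,
                                                  std::int32_t rank) {
  std::int32_t const per_rank = (n_cols + world_size - 1) / world_size;  // ceiling division
  std::int32_t const begin = std::min(rank * per_rank, n_cols);
  std::int32_t const end = std::min(begin + per_rank, n_cols);
  return {begin, end};
}

int main() {
  std::int32_t const kNCols = 16;   // matches the test's kNCols
  std::int32_t const kWorkers = 4;  // matches the new test's worker count
  for (std::int32_t rank = 0; rank < kWorkers; ++rank) {
    auto const [begin, end] = ColumnRange(kNCols, kWorkers, rank);
    std::cout << "rank " << rank << " owns columns [" << begin << ", " << end << ")\n";
  }
}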
@@ -241,7 +244,8 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
 
  bst_node_t nid = 0;
  HistogramBuilder<CPUExpandEntry> histogram;
-  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
+  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed,
+                  is_col_split);
 
  RegTree tree;
 
@@ -284,11 +288,16 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
 }
 
 TEST(CPUHistogram, BuildHist) {
-  TestBuildHistogram(true, false);
-  TestBuildHistogram(false, false);
-  TestBuildHistogram(true, true);
-  TestBuildHistogram(false, true);
+  TestBuildHistogram(true, false, false);
+  TestBuildHistogram(false, false, false);
+  TestBuildHistogram(true, true, false);
+  TestBuildHistogram(false, true, false);
 }
 
+TEST(CPUHistogram, BuildHistColSplit) {
+  auto constexpr kWorkers = 4;
+  RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true);
+  RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true);
+}
+
 namespace {
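The new BuildHistColSplit test wraps the same body in an in-memory multi-worker harness so that collective::GetWorldSize()/GetRank() and the histogram sync see several ranks. As a rough mental model only, the sketch below runs a callable once per simulated worker on its own thread; it is not RunWithInMemoryCommunicator, which additionally sets up a communicator so collective calls work inside the test body.

#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

// Hypothetical, much-simplified harness: run the same test body once per
// simulated worker, each on its own thread.
template <typename Fn, typename... Args>
void RunPerWorker(std::int32_t n_workers, Fn fn, Args... args) {
  std::vector<std::thread> workers;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    workers.emplace_back([=] { fn(rank, args...); });
  }
  for (auto& worker : workers) {
    worker.join();
  }
}

int main() {
  constexpr std::int32_t kWorkers = 4;  // same worker count as the new test
  RunPerWorker(kWorkers, [](std::int32_t rank, bool is_distributed, bool is_col_split) {
    // Stand-in for a per-worker call such as TestBuildHistogram(...); output from
    // different threads may interleave.
    std::cout << "worker " << rank << ": distributed=" << is_distributed
              << " col_split=" << is_col_split << "\n";
  }, /*is_distributed=*/true, /*is_col_split=*/true);
}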
@@ -340,7 +349,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
  HistogramBuilder<CPUExpandEntry> cat_hist;
  for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
    auto total_bins = gidx.cut.TotalBins();
-    cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
+    cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
    cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
                       nodes_for_explicit_hist_build, {}, gpair.HostVector(),
                       force_read_by_column);
@@ -354,7 +363,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
  HistogramBuilder<CPUExpandEntry> onehot_hist;
  for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
    auto total_bins = gidx.cut.TotalBins();
-    onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
+    onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
    onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
                          gpair.HostVector(),
                          force_read_by_column);
@@ -419,7 +428,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
      1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
      256};
 
-  multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false);
+  multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false);
 
  size_t page_idx{0};
  for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
@@ -440,7 +449,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
  common::RowSetCollection row_set_collection;
  InitRowPartitionForTest(&row_set_collection, n_samples);
 
-  single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false);
+  single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false);
  SparsePage concat;
  std::vector<float> hess(m->Info().num_row_, 1.0f);
  for (auto const& page : m->GetBatches<SparsePage>()) {