Support column split in histogram builder (#8811)
This commit is contained in:
parent
40fd3d6d5f
commit
a65ad0bd9c
@ -529,6 +529,11 @@ class DMatrix {
|
|||||||
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*! \brief Whether the data is split column-wise. */
|
||||||
|
bool IsColumnSplit() const {
|
||||||
|
return Info().data_split_mode == DataSplitMode::kCol;
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Load DMatrix from URI.
|
* \brief Load DMatrix from URI.
|
||||||
* \param uri The URI of input.
|
* \param uri The URI of input.
|
||||||
|
|||||||
@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
|||||||
if (!use_sorted) {
|
if (!use_sorted) {
|
||||||
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||||
HostSketchContainer::UseGroup(info),
|
HostSketchContainer::UseGroup(info),
|
||||||
m->Info().data_split_mode == DataSplitMode::kCol, n_threads);
|
m->IsColumnSplit(), n_threads);
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||||
container.PushRowPage(page, info, hessian);
|
container.PushRowPage(page, info, hessian);
|
||||||
}
|
}
|
||||||
@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
|||||||
} else {
|
} else {
|
||||||
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||||
HostSketchContainer::UseGroup(info),
|
HostSketchContainer::UseGroup(info),
|
||||||
m->Info().data_split_mode == DataSplitMode::kCol, n_threads};
|
m->IsColumnSplit(), n_threads};
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||||
container.PushColPage(page, info, hessian);
|
container.PushColPage(page, info, hessian);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
SyncFeatureType(&h_ft);
|
SyncFeatureType(&h_ft);
|
||||||
p_sketch.reset(new common::HostSketchContainer{
|
p_sketch.reset(new common::HostSketchContainer{
|
||||||
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
|
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
|
||||||
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
|
proxy->IsColumnSplit(), ctx_.Threads()});
|
||||||
}
|
}
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||||
|
|||||||
@ -584,7 +584,7 @@ class CPUPredictor : public Predictor {
|
|||||||
|
|
||||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
||||||
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) {
|
if (p_fmat->IsColumnSplit()) {
|
||||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||||
helper.PredictDMatrix(p_fmat, out_preds);
|
helper.PredictDMatrix(p_fmat, out_preds);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@ -29,6 +29,7 @@ class HistogramBuilder {
|
|||||||
size_t n_batches_{0};
|
size_t n_batches_{0};
|
||||||
// Whether XGBoost is running in distributed environment.
|
// Whether XGBoost is running in distributed environment.
|
||||||
bool is_distributed_{false};
|
bool is_distributed_{false};
|
||||||
|
bool is_col_split_{false};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
@ -40,7 +41,7 @@ class HistogramBuilder {
|
|||||||
* of using global rabit variable.
|
* of using global rabit variable.
|
||||||
*/
|
*/
|
||||||
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
|
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
|
||||||
bool is_distributed) {
|
bool is_distributed, bool is_col_split) {
|
||||||
CHECK_GE(n_threads, 1);
|
CHECK_GE(n_threads, 1);
|
||||||
n_threads_ = n_threads;
|
n_threads_ = n_threads;
|
||||||
n_batches_ = n_batches;
|
n_batches_ = n_batches;
|
||||||
@ -50,6 +51,7 @@ class HistogramBuilder {
|
|||||||
buffer_.Init(total_bins);
|
buffer_.Init(total_bins);
|
||||||
builder_ = common::GHistBuilder(total_bins);
|
builder_ = common::GHistBuilder(total_bins);
|
||||||
is_distributed_ = is_distributed;
|
is_distributed_ = is_distributed;
|
||||||
|
is_col_split_ = is_col_split;
|
||||||
// Workaround s390x gcc 7.5.0
|
// Workaround s390x gcc 7.5.0
|
||||||
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
|
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
|
||||||
}
|
}
|
||||||
@ -130,7 +132,7 @@ class HistogramBuilder {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_distributed_) {
|
if (is_distributed_ && !is_col_split_) {
|
||||||
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
|
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
|
||||||
nodes_for_subtraction_trick,
|
nodes_for_subtraction_trick,
|
||||||
starting_index, sync_count);
|
starting_index, sync_count);
|
||||||
|
|||||||
@ -76,7 +76,7 @@ class GloablApproxBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
||||||
collective::IsDistributed());
|
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||||
monitor_->Stop(__func__);
|
monitor_->Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -281,7 +281,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
|||||||
++page_id;
|
++page_id;
|
||||||
}
|
}
|
||||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||||
collective::IsDistributed());
|
collective::IsDistributed(), fmat->IsColumnSplit());
|
||||||
|
|
||||||
auto m_gpair =
|
auto m_gpair =
|
||||||
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
||||||
|
|||||||
@ -48,7 +48,7 @@ void TestAddHistRows(bool is_distributed) {
|
|||||||
|
|
||||||
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
||||||
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
||||||
is_distributed);
|
is_distributed, false);
|
||||||
histogram_builder.AddHistRows(&starting_index, &sync_count,
|
histogram_builder.AddHistRows(&starting_index, &sync_count,
|
||||||
nodes_for_explicit_hist_build_,
|
nodes_for_explicit_hist_build_,
|
||||||
nodes_for_subtraction_trick_, &tree);
|
nodes_for_subtraction_trick_, &tree);
|
||||||
@ -86,7 +86,7 @@ void TestSyncHist(bool is_distributed) {
|
|||||||
|
|
||||||
HistogramBuilder<CPUExpandEntry> histogram;
|
HistogramBuilder<CPUExpandEntry> histogram;
|
||||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
|
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false);
|
||||||
|
|
||||||
common::RowSetCollection row_set_collection_;
|
common::RowSetCollection row_set_collection_;
|
||||||
{
|
{
|
||||||
@ -226,11 +226,14 @@ TEST(CPUHistogram, SyncHist) {
|
|||||||
TestSyncHist(false);
|
TestSyncHist(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
||||||
size_t constexpr kNRows = 8, kNCols = 16;
|
size_t constexpr kNRows = 8, kNCols = 16;
|
||||||
int32_t constexpr kMaxBins = 4;
|
int32_t constexpr kMaxBins = 4;
|
||||||
auto p_fmat =
|
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
if (is_col_split) {
|
||||||
|
p_fmat = std::shared_ptr<DMatrix>{
|
||||||
|
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
||||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||||
|
|
||||||
@ -241,7 +244,8 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
|||||||
|
|
||||||
bst_node_t nid = 0;
|
bst_node_t nid = 0;
|
||||||
HistogramBuilder<CPUExpandEntry> histogram;
|
HistogramBuilder<CPUExpandEntry> histogram;
|
||||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
|
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed,
|
||||||
|
is_col_split);
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
|
|
||||||
@ -284,11 +288,16 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUHistogram, BuildHist) {
|
TEST(CPUHistogram, BuildHist) {
|
||||||
TestBuildHistogram(true, false);
|
TestBuildHistogram(true, false, false);
|
||||||
TestBuildHistogram(false, false);
|
TestBuildHistogram(false, false, false);
|
||||||
TestBuildHistogram(true, true);
|
TestBuildHistogram(true, true, false);
|
||||||
TestBuildHistogram(false, true);
|
TestBuildHistogram(false, true, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CPUHistogram, BuildHistColSplit) {
|
||||||
|
auto constexpr kWorkers = 4;
|
||||||
|
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true);
|
||||||
|
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -340,7 +349,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
HistogramBuilder<CPUExpandEntry> cat_hist;
|
HistogramBuilder<CPUExpandEntry> cat_hist;
|
||||||
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
|
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||||
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||||
nodes_for_explicit_hist_build, {}, gpair.HostVector(),
|
nodes_for_explicit_hist_build, {}, gpair.HostVector(),
|
||||||
force_read_by_column);
|
force_read_by_column);
|
||||||
@ -354,7 +363,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
||||||
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
|
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||||
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
||||||
gpair.HostVector(),
|
gpair.HostVector(),
|
||||||
force_read_by_column);
|
force_read_by_column);
|
||||||
@ -419,7 +428,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
||||||
256};
|
256};
|
||||||
|
|
||||||
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false);
|
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false);
|
||||||
|
|
||||||
size_t page_idx{0};
|
size_t page_idx{0};
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
||||||
@ -440,7 +449,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
common::RowSetCollection row_set_collection;
|
common::RowSetCollection row_set_collection;
|
||||||
InitRowPartitionForTest(&row_set_collection, n_samples);
|
InitRowPartitionForTest(&row_set_collection, n_samples);
|
||||||
|
|
||||||
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false);
|
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false);
|
||||||
SparsePage concat;
|
SparsePage concat;
|
||||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user