Support column split in approx tree method (#8847)

This commit is contained in:
Rong Ou
2023-03-01 11:59:07 -08:00
committed by GitHub
parent 6d8afb2218
commit 7cbaee9916
6 changed files with 101 additions and 19 deletions

View File

@@ -38,6 +38,7 @@ class HistEvaluator {
TrainParam param_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
bool is_col_split_{false};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;
@@ -355,7 +356,24 @@ class HistEvaluator {
tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const num_entries = entries.size();
std::vector<ExpandEntry> buffer{num_entries * world};
std::copy_n(entries.cbegin(), num_entries, buffer.begin() + num_entries * rank);
collective::Allgather(buffer.data(), buffer.size() * sizeof(ExpandEntry));
for (auto worker = 0; worker < world; ++worker) {
for (auto nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(buffer[worker * num_entries + nidx_in_set].split);
}
}
}
}
// Add splits to tree, handles all statistic
void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();
@@ -429,7 +447,8 @@ class HistEvaluator {
std::shared_ptr<common::ColumnSampler> sampler)
: ctx_{ctx}, param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId} {
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_.colsample_bynode, param_.colsample_bylevel,

View File

@@ -98,7 +98,7 @@ class HistogramBuilder {
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
if (is_distributed_) {
if (is_distributed_ && !is_col_split_) {
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
} else {

View File

@@ -90,7 +90,9 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
if (p_fmat->IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
}
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes);