Support column split in approx tree method (#8847)
This commit is contained in:
@@ -912,6 +912,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
if (!cache_file.empty()) {
|
||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||
}
|
||||
LOG(CONSOLE) << "Splitting data by column";
|
||||
auto* sliced = dmat->SliceCol(npart, partid);
|
||||
delete dmat;
|
||||
return sliced;
|
||||
|
||||
@@ -38,6 +38,7 @@ class HistEvaluator {
|
||||
TrainParam param_;
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
TreeEvaluator tree_evaluator_;
|
||||
bool is_col_split_{false};
|
||||
FeatureInteractionConstraintHost interaction_constraints_;
|
||||
std::vector<NodeEntry> snode_;
|
||||
|
||||
@@ -355,7 +356,24 @@ class HistEvaluator {
|
||||
tloc_candidates[n_threads * nidx_in_set + tidx].split);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_col_split_) {
|
||||
// With column-wise data split, we gather the best splits from all the workers and update the
|
||||
// expand entries accordingly.
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const num_entries = entries.size();
|
||||
std::vector<ExpandEntry> buffer{num_entries * world};
|
||||
std::copy_n(entries.cbegin(), num_entries, buffer.begin() + num_entries * rank);
|
||||
collective::Allgather(buffer.data(), buffer.size() * sizeof(ExpandEntry));
|
||||
for (auto worker = 0; worker < world; ++worker) {
|
||||
for (auto nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
|
||||
entries[nidx_in_set].split.Update(buffer[worker * num_entries + nidx_in_set].split);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add splits to tree, handles all statistic
|
||||
void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
@@ -429,7 +447,8 @@ class HistEvaluator {
|
||||
std::shared_ptr<common::ColumnSampler> sampler)
|
||||
: ctx_{ctx}, param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId} {
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
||||
interaction_constraints_.Configure(param, info.num_col_);
|
||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
|
||||
@@ -98,7 +98,7 @@ class HistogramBuilder {
|
||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||
RegTree const *p_tree) {
|
||||
if (is_distributed_) {
|
||||
if (is_distributed_ && !is_col_split_) {
|
||||
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
|
||||
nodes_for_subtraction_trick, p_tree);
|
||||
} else {
|
||||
|
||||
@@ -90,7 +90,9 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
if (p_fmat->IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
}
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
size_t i = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||
|
||||
Reference in New Issue
Block a user