Support column split in multi-target hist (#9171)

This commit is contained in:
Rong Ou
2023-05-26 01:56:05 -07:00
committed by GitHub
parent acd363033e
commit 5b69534b43
17 changed files with 386 additions and 96 deletions

View File

@@ -209,7 +209,7 @@ class PartitionBuilder {
BitVector* decision_bits, BitVector* missing_bits) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
std::size_t nid = nodes[node_in_set].nid;
bst_feature_t fid = tree[nid].SplitIndex();
bst_feature_t fid = tree.SplitIndex(nid);
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);
auto const& cut_values = gmat.cut.Values();
@@ -270,7 +270,7 @@ class PartitionBuilder {
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
std::size_t nid = nodes[node_in_set].nid;
bool default_left = tree[nid].DefaultLeft();
bool default_left = tree.DefaultLeft(nid);
auto pred = [&](auto ridx) {
bool go_left = default_left;

View File

@@ -7,7 +7,6 @@
#include <utility>
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
@@ -143,6 +142,7 @@ struct QuantileAllreduce {
template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -168,7 +168,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
// Gather all column pointers
collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@@ -196,7 +196,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
"Unexpected size of sketch entry.");
collective::Allreduce<collective::Operation::kSum>(
collective::GlobalSum(
info,
reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
}
@@ -222,8 +223,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
global_feat_ptrs.size());
collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());
// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
@@ -236,8 +236,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
global_worker_ptr.size());
collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
@@ -249,8 +248,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
global_categories.size());
collective::GlobalSum(info, global_categories.data(), global_categories.size());
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
@@ -323,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
std::vector<typename WQSketch::Entry> global_sketches;
this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan, &global_sketches);
this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
@@ -371,7 +369,9 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
InvalidCategory();
}
auto &cut_values = cuts->cut_values_.HostVector();
auto max_cat = *std::max_element(categories.cbegin(), categories.cend());
// With column-wise data split, the categories may be empty.
auto max_cat =
categories.empty() ? 0.0f : *std::max_element(categories.cbegin(), categories.cend());
CheckMaxCat(max_cat, categories.size());
for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) {
cut_values.push_back(i);

View File

@@ -822,7 +822,8 @@ class SketchContainerImpl {
return group_ind;
}
// Gather sketches from all workers.
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
void GatherSketchInfo(MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<bst_row_t> *p_worker_segments,
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);