Support column split in multi-target hist (#9171)
This commit is contained in:
@@ -209,7 +209,7 @@ class PartitionBuilder {
|
||||
BitVector* decision_bits, BitVector* missing_bits) {
|
||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||
std::size_t nid = nodes[node_in_set].nid;
|
||||
bst_feature_t fid = tree[nid].SplitIndex();
|
||||
bst_feature_t fid = tree.SplitIndex(nid);
|
||||
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||
auto node_cats = tree.NodeCats(nid);
|
||||
auto const& cut_values = gmat.cut.Values();
|
||||
@@ -270,7 +270,7 @@ class PartitionBuilder {
|
||||
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||
std::size_t nid = nodes[node_in_set].nid;
|
||||
bool default_left = tree[nid].DefaultLeft();
|
||||
bool default_left = tree.DefaultLeft(nid);
|
||||
|
||||
auto pred = [&](auto ridx) {
|
||||
bool go_left = default_left;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "categorical.h"
|
||||
#include "hist_util.h"
|
||||
@@ -143,6 +142,7 @@ struct QuantileAllreduce {
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
MetaInfo const& info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches) {
|
||||
@@ -168,7 +168,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
|
||||
|
||||
// Gather all column pointers
|
||||
collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
|
||||
collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t back = (i + 1) * (n_columns + 1) - 1;
|
||||
auto n_entries = sketches_scan.at(back);
|
||||
@@ -196,7 +196,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
|
||||
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
|
||||
"Unexpected size of sketch entry.");
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
collective::GlobalSum(
|
||||
info,
|
||||
reinterpret_cast<float *>(global_sketches.data()),
|
||||
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
|
||||
}
|
||||
@@ -222,8 +223,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
|
||||
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
|
||||
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
|
||||
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
|
||||
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
|
||||
global_feat_ptrs.size());
|
||||
collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());
|
||||
|
||||
// move all categories into a flatten vector to prepare for allreduce
|
||||
size_t total = feature_ptr.back();
|
||||
@@ -236,8 +236,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
|
||||
// indptr for indexing workers
|
||||
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
|
||||
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
|
||||
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
|
||||
global_worker_ptr.size());
|
||||
collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
|
||||
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
|
||||
// total number of categories in all workers with all features
|
||||
auto gtotal = global_worker_ptr.back();
|
||||
@@ -249,8 +248,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
|
||||
CHECK_EQ(rank_size, total);
|
||||
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
|
||||
// gather values from all workers.
|
||||
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
|
||||
global_categories.size());
|
||||
collective::GlobalSum(info, global_categories.data(), global_categories.size());
|
||||
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
|
||||
categories_.size()};
|
||||
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
|
||||
@@ -323,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
|
||||
|
||||
std::vector<typename WQSketch::Entry> global_sketches;
|
||||
this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan, &global_sketches);
|
||||
this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
|
||||
|
||||
std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
|
||||
|
||||
@@ -371,7 +369,9 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
||||
InvalidCategory();
|
||||
}
|
||||
auto &cut_values = cuts->cut_values_.HostVector();
|
||||
auto max_cat = *std::max_element(categories.cbegin(), categories.cend());
|
||||
// With column-wise data split, the categories may be empty.
|
||||
auto max_cat =
|
||||
categories.empty() ? 0.0f : *std::max_element(categories.cbegin(), categories.cend());
|
||||
CheckMaxCat(max_cat, categories.size());
|
||||
for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) {
|
||||
cut_values.push_back(i);
|
||||
|
||||
@@ -822,7 +822,8 @@ class SketchContainerImpl {
|
||||
return group_ind;
|
||||
}
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
void GatherSketchInfo(MetaInfo const& info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
||||
|
||||
Reference in New Issue
Block a user