Support column split in multi-target hist (#9171)

parent acd363033e
commit 5b69534b43
@@ -3,6 +3,7 @@
  */
 #pragma once
 #include <string>
+#include <vector>

 #include "communicator.h"

@@ -224,5 +225,46 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
   Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
 }

+template <typename T>
+struct AllgatherVResult {
+  std::vector<std::size_t> offsets;
+  std::vector<std::size_t> sizes;
+  std::vector<T> result;
+};
+
+/**
+ * @brief Gathers variable-length data from all processes and distributes it to all processes.
+ *
+ * We assume each worker has the same number of inputs, but each input may be of a different size.
+ *
+ * @param inputs All the inputs from the local worker.
+ * @param sizes Sizes of each input.
+ */
+template <typename T>
+inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
+                                      std::vector<std::size_t> const &sizes) {
+  auto num_inputs = sizes.size();
+
+  // Gather the sizes across all workers.
+  std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
+  std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
+  collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
+
+  // Calculate input offsets (std::exclusive_scan).
+  std::vector<std::size_t> offsets(all_sizes.size());
+  for (auto i = 1; i < offsets.size(); i++) {
+    offsets[i] = offsets[i - 1] + all_sizes[i - 1];
+  }
+
+  // Gather all the inputs.
+  auto total_input_size = offsets.back() + all_sizes.back();
+  std::vector<T> all_inputs(total_input_size);
+  std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
+  // We cannot use allgather here, since each worker might have a different size.
+  Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
+
+  return {offsets, all_sizes, all_inputs};
+}
+
 }  // namespace collective
 }  // namespace xgboost
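A minimal usage sketch of the new AllgatherV helper (the worker payloads below are hypothetical; it assumes the collective communicator is initialized and every worker passes the same number of inputs):

    // Rank 0 holds {1, 2, 3} as inputs of sizes {1, 2}; rank 1 holds {4, 5, 6, 7} as sizes {3, 1}.
    std::vector<int> inputs{1, 2, 3};          // rank 0's inputs, flattened
    std::vector<std::size_t> sizes{1, 2};      // one size per input
    auto gathered = xgboost::collective::AllgatherV(inputs, sizes);
    // Every worker then sees the same result:
    //   gathered.sizes   == {1, 2, 3, 1}           (all sizes, in rank order)
    //   gathered.offsets == {0, 1, 3, 6}           (exclusive scan of the sizes)
    //   gathered.result  == {1, 2, 3, 4, 5, 6, 7}  (all inputs, concatenated)
    // so the i-th input overall occupies result[offsets[i], offsets[i] + sizes[i]).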
@@ -209,7 +209,7 @@ class PartitionBuilder {
                              BitVector* decision_bits, BitVector* missing_bits) {
     common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
     std::size_t nid = nodes[node_in_set].nid;
-    bst_feature_t fid = tree[nid].SplitIndex();
+    bst_feature_t fid = tree.SplitIndex(nid);
     bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
     auto node_cats = tree.NodeCats(nid);
     auto const& cut_values = gmat.cut.Values();

@@ -270,7 +270,7 @@ class PartitionBuilder {
     common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
     common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
     std::size_t nid = nodes[node_in_set].nid;
-    bool default_left = tree[nid].DefaultLeft();
+    bool default_left = tree.DefaultLeft(nid);

     auto pred = [&](auto ridx) {
       bool go_left = default_left;
@@ -7,7 +7,6 @@
 #include <utility>

 #include "../collective/aggregator.h"
-#include "../collective/communicator-inl.h"
 #include "../data/adapter.h"
 #include "categorical.h"
 #include "hist_util.h"

@@ -143,6 +142,7 @@ struct QuantileAllreduce {

 template <typename WQSketch>
 void SketchContainerImpl<WQSketch>::GatherSketchInfo(
+    MetaInfo const& info,
     std::vector<typename WQSketch::SummaryContainer> const &reduced,
     std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
     std::vector<typename WQSketch::Entry> *p_global_sketches) {

@@ -168,7 +168,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
   std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);

   // Gather all column pointers
-  collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
+  collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
   for (int32_t i = 0; i < world; ++i) {
     size_t back = (i + 1) * (n_columns + 1) - 1;
     auto n_entries = sketches_scan.at(back);

@@ -196,7 +196,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(

   static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
                 "Unexpected size of sketch entry.");
-  collective::Allreduce<collective::Operation::kSum>(
+  collective::GlobalSum(
+      info,
       reinterpret_cast<float *>(global_sketches.data()),
       global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
 }

@@ -222,8 +223,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
   std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
   size_t feat_begin = rank * feature_ptr.size();  // pointer to current worker
   std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
-  collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
-                                                     global_feat_ptrs.size());
+  collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());

   // move all categories into a flatten vector to prepare for allreduce
   size_t total = feature_ptr.back();

@@ -236,8 +236,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
   // indptr for indexing workers
   std::vector<size_t> global_worker_ptr(world_size + 1, 0);
   global_worker_ptr[rank + 1] = total;  // shift 1 to right for constructing the indptr
-  collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
-                                                     global_worker_ptr.size());
+  collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
   std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
   // total number of categories in all workers with all features
   auto gtotal = global_worker_ptr.back();

@@ -249,8 +248,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
   CHECK_EQ(rank_size, total);
   std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
   // gather values from all workers.
-  collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
-                                                     global_categories.size());
+  collective::GlobalSum(info, global_categories.data(), global_categories.size());
   QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
                                             categories_.size()};
   ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {

@@ -323,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
   std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);

   std::vector<typename WQSketch::Entry> global_sketches;
-  this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan, &global_sketches);
+  this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);

   std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);


@@ -371,7 +369,9 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
     InvalidCategory();
   }
   auto &cut_values = cuts->cut_values_.HostVector();
-  auto max_cat = *std::max_element(categories.cbegin(), categories.cend());
+  // With column-wise data split, the categories may be empty.
+  auto max_cat =
+      categories.empty() ? 0.0f : *std::max_element(categories.cbegin(), categories.cend());
   CheckMaxCat(max_cat, categories.size());
   for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) {
     cut_values.push_back(i);
@@ -822,7 +822,8 @@ class SketchContainerImpl {
     return group_ind;
   }
   // Gather sketches from all workers.
-  void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
+  void GatherSketchInfo(MetaInfo const& info,
+                        std::vector<typename WQSketch::SummaryContainer> const &reduced,
                         std::vector<bst_row_t> *p_worker_segments,
                         std::vector<bst_row_t> *p_sketches_scan,
                         std::vector<typename WQSketch::Entry> *p_global_sketches);

@@ -698,6 +698,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
     this->feature_type_names = that.feature_type_names;
     auto &h_feature_types = feature_types.HostVector();
     LoadFeatureType(this->feature_type_names, &h_feature_types);
+  } else if (!that.feature_types.Empty()) {
+    this->feature_types.Resize(that.feature_types.Size());
+    this->feature_types.Copy(that.feature_types);
   }
   if (!that.feature_weights.Empty()) {
     this->feature_weights.Resize(that.feature_weights.Size());

@@ -25,7 +25,6 @@
 #include "xgboost/linalg.h"  // for Constants, Vector

 namespace xgboost::tree {
-template <typename ExpandEntry>
 class HistEvaluator {
  private:
   struct NodeEntry {

@@ -285,10 +284,42 @@ class HistEvaluator {
     return left_sum;
   }

+  /**
+   * @brief Gather the expand entries from all the workers.
+   * @param entries Local expand entries on this worker.
+   * @return Global expand entries gathered from all workers.
+   */
+  std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
+    auto const world = collective::GetWorldSize();
+    auto const rank = collective::GetRank();
+    auto const num_entries = entries.size();
+
+    // First, gather all the primitive fields.
+    std::vector<CPUExpandEntry> all_entries(num_entries * world);
+    std::vector<uint32_t> cat_bits;
+    std::vector<std::size_t> cat_bits_sizes;
+    for (auto i = 0; i < num_entries; i++) {
+      all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
+    }
+    collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry));
+
+    // Gather all the cat_bits.
+    auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes);
+
+    common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
+      // Copy the cat_bits back into all expand entries.
+      all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
+      std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
+                  all_entries[i].split.cat_bits.begin());
+    });
+
+    return all_entries;
+  }
+
  public:
   void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
                       common::Span<FeatureType const> feature_types, const RegTree &tree,
-                      std::vector<ExpandEntry> *p_entries) {
+                      std::vector<CPUExpandEntry> *p_entries) {
     auto n_threads = ctx_->Threads();
     auto& entries = *p_entries;
     // All nodes are on the same level, so we can store the shared ptr.

@@ -306,7 +337,7 @@ class HistEvaluator {
       return features[nidx_in_set]->Size();
     }, grain_size);

-    std::vector<ExpandEntry> tloc_candidates(n_threads * entries.size());
+    std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
     for (size_t i = 0; i < entries.size(); ++i) {
       for (decltype(n_threads) j = 0; j < n_threads; ++j) {
         tloc_candidates[i * n_threads + j] = entries[i];

@@ -365,22 +396,18 @@ class HistEvaluator {
     if (is_col_split_) {
       // With column-wise data split, we gather the best splits from all the workers and update the
       // expand entries accordingly.
-      auto const world = collective::GetWorldSize();
-      auto const rank = collective::GetRank();
-      auto const num_entries = entries.size();
-      std::vector<ExpandEntry> buffer{num_entries * world};
-      std::copy_n(entries.cbegin(), num_entries, buffer.begin() + num_entries * rank);
-      collective::Allgather(buffer.data(), buffer.size() * sizeof(ExpandEntry));
-      for (auto worker = 0; worker < world; ++worker) {
+      auto all_entries = Allgather(entries);
+      for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
         for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
-          entries[nidx_in_set].split.Update(buffer[worker * num_entries + nidx_in_set].split);
+          entries[nidx_in_set].split.Update(
+              all_entries[worker * entries.size() + nidx_in_set].split);
         }
       }
     }
   }

   // Add splits to tree, handles all statistic
-  void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
+  void ApplyTreeSplit(CPUExpandEntry const& candidate, RegTree *p_tree) {
     auto evaluator = tree_evaluator_.GetEvaluator();
     RegTree &tree = *p_tree;

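The new HistEvaluator::Allgather above combines two collectives: the entries' fixed-size fields travel as raw bytes through Allgather, while the variable-length categorical bitmask goes through the new AllgatherV and is scattered back per entry. A toy sketch of that pattern (ToyEntry and GatherToyEntries are hypothetical names, not xgboost code; only Allgather, AllgatherV, GetWorldSize, and GetRank from the communicator header are real):

    #include <algorithm>  // for copy_n
    #include <cstddef>    // for size_t
    #include <cstdint>    // for uint32_t
    #include <vector>     // for vector
    // assumes "../collective/communicator-inl.h" for the collective calls

    struct ToyEntry {
      int nid{0};
      float loss_chg{0.0f};
    };  // only fixed-size members, so a byte-wise Allgather is safe

    std::vector<ToyEntry> GatherToyEntries(std::vector<ToyEntry> const &local,
                                           std::vector<uint32_t> const &local_bits,
                                           std::vector<std::size_t> const &local_bit_sizes) {
      auto const world = xgboost::collective::GetWorldSize();
      auto const rank = xgboost::collective::GetRank();
      // Fixed-size fields: place the local entries in this rank's slot, then gather the bytes.
      std::vector<ToyEntry> all(local.size() * world);
      std::copy_n(local.cbegin(), local.size(), all.begin() + local.size() * rank);
      xgboost::collective::Allgather(all.data(), all.size() * sizeof(ToyEntry));
      // Variable-length payload (one bitmask blob per entry): gather with AllgatherV, then
      // scatter gathered.result back per entry via gathered.offsets / gathered.sizes.
      auto gathered = xgboost::collective::AllgatherV(local_bits, local_bit_sizes);
      (void)gathered;  // the real code copies these into each entry's split.cat_bits
      return all;
    }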
@@ -465,6 +492,7 @@ class HistMultiEvaluator {
   FeatureInteractionConstraintHost interaction_constraints_;
   std::shared_ptr<common::ColumnSampler> column_sampler_;
   Context const *ctx_;
+  bool is_col_split_{false};

  private:
   static double MultiCalcSplitGain(TrainParam const &param,

@@ -543,6 +571,57 @@ class HistMultiEvaluator {
     return false;
   }

+  /**
+   * @brief Gather the expand entries from all the workers.
+   * @param entries Local expand entries on this worker.
+   * @return Global expand entries gathered from all workers.
+   */
+  std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
+    auto const world = collective::GetWorldSize();
+    auto const rank = collective::GetRank();
+    auto const num_entries = entries.size();
+
+    // First, gather all the primitive fields.
+    std::vector<MultiExpandEntry> all_entries(num_entries * world);
+    std::vector<uint32_t> cat_bits;
+    std::vector<std::size_t> cat_bits_sizes;
+    std::vector<GradientPairPrecise> gradients;
+    for (auto i = 0; i < num_entries; i++) {
+      all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
+                                                         &gradients);
+    }
+    collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry));
+
+    // Gather all the cat_bits.
+    auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes);
+
+    // Gather all the gradients.
+    auto const num_gradients = gradients.size();
+    std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
+    std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
+    collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise));
+
+    auto const total_entries = num_entries * world;
+    auto const gradients_per_entry = num_gradients / num_entries;
+    auto const gradients_per_side = gradients_per_entry / 2;
+    common::ParallelFor(total_entries, ctx_->Threads(), [&] (auto i) {
+      // Copy the cat_bits back into all expand entries.
+      all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
+      std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
+                  gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());
+
+      // Copy the gradients back into all expand entries.
+      all_entries[i].split.left_sum.resize(gradients_per_side);
+      std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
+                  all_entries[i].split.left_sum.begin());
+      all_entries[i].split.right_sum.resize(gradients_per_side);
+      std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
+                  gradients_per_side, all_entries[i].split.right_sum.begin());
+    });
+
+    return all_entries;
+  }
+
  public:
   void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
                       common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {

@@ -597,6 +676,18 @@ class HistMultiEvaluator {
         entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
       }
     }
+
+    if (is_col_split_) {
+      // With column-wise data split, we gather the best splits from all the workers and update the
+      // expand entries accordingly.
+      auto all_entries = Allgather(entries);
+      for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
+        for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
+          entries[nidx_in_set].split.Update(
+              all_entries[worker * entries.size() + nidx_in_set].split);
+        }
+      }
+    }
   }

   linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {

@@ -660,7 +751,10 @@ class HistMultiEvaluator {

   explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
                               std::shared_ptr<common::ColumnSampler> sampler)
-      : param_{param}, column_sampler_{std::move(sampler)}, ctx_{ctx} {
+      : param_{param},
+        column_sampler_{std::move(sampler)},
+        ctx_{ctx},
+        is_col_split_{info.IsColumnSplit()} {
     interaction_constraints_.Configure(*param, info.num_col_);
     column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                           param_->colsample_bynode, param_->colsample_bylevel,
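The flattened gradient buffer gathered above packs each entry's left and right sums back to back, so the offsets used in the ParallelFor are plain integer arithmetic. A small worked example with hypothetical sizes (2 entries per worker, 3 targets):

    #include <cassert>
    #include <cstddef>

    int main() {
      // Each MultiExpandEntry contributed left_sum (3 targets) + right_sum (3 targets) = 6 gradients.
      std::size_t num_entries = 2, num_gradients = 12;
      std::size_t gradients_per_entry = num_gradients / num_entries;  // 6
      std::size_t gradients_per_side = gradients_per_entry / 2;       // 3
      // Entry i in the allgathered buffer (from any worker) reads:
      std::size_t i = 3;                                              // e.g. second entry of worker 1
      std::size_t left_begin = i * gradients_per_entry;               // 18
      std::size_t right_begin = left_begin + gradients_per_side;      // 21
      assert(left_begin == 18 && right_begin == 21);
      return 0;
    }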
@@ -70,6 +70,22 @@ struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
     os << "split:\n" << e.split << std::endl;
     return os;
   }
+
+  /**
+   * @brief Copy primitive fields into this, and collect cat_bits into a vector.
+   *
+   * This is used for allgather.
+   *
+   * @param that The other entry to copy from
+   * @param collected_cat_bits The vector to collect cat_bits
+   * @param cat_bits_sizes The sizes of the collected cat_bits
+   */
+  void CopyAndCollect(CPUExpandEntry const& that, std::vector<uint32_t>* collected_cat_bits,
+                      std::vector<std::size_t>* cat_bits_sizes) {
+    nid = that.nid;
+    depth = that.depth;
+    split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes);
+  }
 };

 struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {

@@ -119,6 +135,24 @@ struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
     os << "]\n";
     return os;
   }
+
+  /**
+   * @brief Copy primitive fields into this, and collect cat_bits and gradients into vectors.
+   *
+   * This is used for allgather.
+   *
+   * @param that The other entry to copy from
+   * @param collected_cat_bits The vector to collect cat_bits
+   * @param cat_bits_sizes The sizes of the collected cat_bits
+   * @param collected_gradients The vector to collect gradients
+   */
+  void CopyAndCollect(MultiExpandEntry const& that, std::vector<uint32_t>* collected_cat_bits,
+                      std::vector<std::size_t>* cat_bits_sizes,
+                      std::vector<GradientPairPrecise>* collected_gradients) {
+    nid = that.nid;
+    depth = that.depth;
+    split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes, collected_gradients);
+  }
 };
 }  // namespace xgboost::tree
 #endif  // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
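A hypothetical single-entry round trip of the new CopyAndCollect, sketching how the categorical bits leave the entry before the byte-wise gather and are restored afterwards (the bitmask value and variable names are made up; in the real flow the restore uses AllgatherV's offsets and sizes):

    CPUExpandEntry local{/*nidx=*/0, /*depth=*/0};
    local.split.cat_bits = {0b1010u};                 // one 32-bit block of category bits
    std::vector<uint32_t> collected_bits;
    std::vector<std::size_t> bit_sizes;
    CPUExpandEntry packed;
    packed.CopyAndCollect(local, &collected_bits, &bit_sizes);  // primitives copied, bits appended
    // ... byte-wise Allgather of the packed entries, AllgatherV of collected_bits ...
    packed.split.cat_bits.resize(bit_sizes[0]);
    std::copy_n(collected_bits.cbegin(), bit_sizes[0], packed.split.cat_bits.begin());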
@@ -419,6 +419,60 @@ struct SplitEntryContainer {
        << "right_sum: " << s.right_sum << std::endl;
     return os;
   }
+
+  /**
+   * @brief Copy primitive fields into this, and collect cat_bits into a vector.
+   *
+   * This is used for allgather.
+   *
+   * @param that The other entry to copy from
+   * @param collected_cat_bits The vector to collect cat_bits
+   * @param cat_bits_sizes The sizes of the collected cat_bits
+   */
+  void CopyAndCollect(SplitEntryContainer<GradientT> const &that,
+                      std::vector<uint32_t> *collected_cat_bits,
+                      std::vector<std::size_t> *cat_bits_sizes) {
+    loss_chg = that.loss_chg;
+    sindex = that.sindex;
+    split_value = that.split_value;
+    is_cat = that.is_cat;
+    static_assert(std::is_trivially_copyable_v<GradientT>);
+    left_sum = that.left_sum;
+    right_sum = that.right_sum;
+    collected_cat_bits->insert(collected_cat_bits->end(), that.cat_bits.cbegin(),
+                               that.cat_bits.cend());
+    cat_bits_sizes->emplace_back(that.cat_bits.size());
+  }
+
+  /**
+   * @brief Copy primitive fields into this, and collect cat_bits and gradient sums into vectors.
+   *
+   * This is used for allgather.
+   *
+   * @param that The other entry to copy from
+   * @param collected_cat_bits The vector to collect cat_bits
+   * @param cat_bits_sizes The sizes of the collected cat_bits
+   * @param collected_gradients The vector to collect gradients
+   */
+  template <typename G>
+  void CopyAndCollect(SplitEntryContainer<GradientT> const &that,
+                      std::vector<uint32_t> *collected_cat_bits,
+                      std::vector<std::size_t> *cat_bits_sizes,
+                      std::vector<G> *collected_gradients) {
+    loss_chg = that.loss_chg;
+    sindex = that.sindex;
+    split_value = that.split_value;
+    is_cat = that.is_cat;
+    collected_cat_bits->insert(collected_cat_bits->end(), that.cat_bits.cbegin(),
+                               that.cat_bits.cend());
+    cat_bits_sizes->emplace_back(that.cat_bits.size());
+    static_assert(!std::is_trivially_copyable_v<GradientT>);
+    collected_gradients->insert(collected_gradients->end(), that.left_sum.cbegin(),
+                                that.left_sum.cend());
+    collected_gradients->insert(collected_gradients->end(), that.right_sum.cbegin(),
+                                that.right_sum.cend());
+  }
+
   /*!\return feature index to split on */
   [[nodiscard]] bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
   /*!\return whether missing value goes to left branch */
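The opposite static_asserts mark why there are two overloads: for the single-target SplitEntry the gradient type is GradStats (two doubles), so left_sum and right_sum can ride along with the other primitive fields in the byte-wise gather, while the multi-target entry keeps its sums in a resizable container of GradientPairPrecise (assumed to be std::vector here), which cannot, so they are flattened into collected_gradients instead. An illustrative compile-time check of that assumption:

    // Illustrative only; assumes <type_traits>, <vector>, and the xgboost tree headers are available.
    static_assert(std::is_trivially_copyable_v<xgboost::tree::GradStats>);
    static_assert(!std::is_trivially_copyable_v<std::vector<xgboost::GradientPairPrecise>>);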
@@ -44,7 +44,7 @@ class GloablApproxBuilder {
  protected:
   TrainParam const *param_;
   std::shared_ptr<common::ColumnSampler> col_sampler_;
-  HistEvaluator<CPUExpandEntry> evaluator_;
+  HistEvaluator evaluator_;
   HistogramBuilder<CPUExpandEntry> histogram_builder_;
   Context const *ctx_;
   ObjInfo const *const task_;

@@ -13,6 +13,7 @@
 #include <utility>  // for move, swap
 #include <vector>   // for vector

+#include "../collective/aggregator.h"  // for GlobalSum
 #include "../collective/communicator-inl.h"  // for Allreduce, IsDistributed
 #include "../collective/communicator.h"  // for Operation
 #include "../common/hist_util.h"  // for HistogramCuts, HistCollection

@@ -200,8 +201,8 @@ class MultiTargetHistBuilder {
       }
     }
     CHECK(root_sum.CContiguous());
-    collective::Allreduce<collective::Operation::kSum>(
-        reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2);
+    collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
+                          root_sum.Size() * 2);

     std::vector<MultiExpandEntry> nodes{best};
     std::size_t i = 0;

@@ -335,7 +336,7 @@ class HistBuilder {
   common::Monitor *monitor_;
   TrainParam const *param_;
   std::shared_ptr<common::ColumnSampler> col_sampler_;
-  std::unique_ptr<HistEvaluator<CPUExpandEntry>> evaluator_;
+  std::unique_ptr<HistEvaluator> evaluator_;
   std::vector<CommonRowPartitioner> partitioner_;

   // back pointers to tree and data matrix

@@ -354,7 +355,7 @@ class HistBuilder {
       : monitor_{monitor},
         param_{param},
         col_sampler_{std::move(column_sampler)},
-        evaluator_{std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx, param, fmat->Info(),
-                                                                   col_sampler_)},
+        evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(),
+                                                   col_sampler_)},
         p_last_fmat_(fmat),
         histogram_builder_{new HistogramBuilder<CPUExpandEntry>},

@@ -395,8 +396,7 @@ class HistBuilder {
     }
     histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
                               collective::IsDistributed(), fmat->Info().IsColumnSplit());
-    evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
-                                                                 col_sampler_);
+    evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
     p_last_tree_ = p_tree;
     monitor_->Stop(__func__);
   }

@@ -455,8 +455,7 @@ class HistBuilder {
       for (auto const &grad : gpair_h) {
         grad_stat.Add(grad.GetGrad(), grad.GetHess());
       }
-      collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&grad_stat),
-                                                         2);
+      collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&grad_stat), 2);
     }

     auto weight = evaluator_->InitRoot(GradStats{grad_stat});

@@ -23,6 +23,7 @@

 #include "../../src/collective/communicator-inl.h"
 #include "../../src/common/common.h"
+#include "../../src/common/threading_utils.h"
 #include "../../src/data/array_interface.h"
 #include "filesystem.h"  // dmlc::TemporaryDirectory
 #include "xgboost/linalg.h"

@@ -388,6 +389,23 @@ inline Context CreateEmptyGenericParam(int gpu_id) {
   return tparam;
 }

+inline std::unique_ptr<HostDeviceVector<GradientPair>> GenerateGradients(
+    std::size_t rows, bst_target_t n_targets = 1) {
+  auto p_gradients = std::make_unique<HostDeviceVector<GradientPair>>(rows * n_targets);
+  auto& h_gradients = p_gradients->HostVector();
+
+  xgboost::SimpleLCG gen;
+  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
+
+  for (std::size_t i = 0; i < rows * n_targets; ++i) {
+    auto grad = dist(&gen);
+    auto hess = dist(&gen);
+    h_gradients[i] = GradientPair{grad, hess};
+  }
+
+  return p_gradients;
+}
+
 /**
  * \brief Make a context that uses CUDA.
  */
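A quick sketch of what the widened test helper yields for a multi-target case (the row and target counts are hypothetical):

    auto p_gradients = GenerateGradients(/*rows=*/32, /*n_targets=*/3);
    // One GradientPair per (row, target): a flat buffer of 32 * 3 = 96 pairs.
    ASSERT_EQ(p_gradients->Size(), 32 * 3);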
@@ -509,11 +527,7 @@ void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&
     xgboost::collective::Finalize();
   };
 #if defined(_OPENMP)
-#pragma omp parallel num_threads(world_size)
-  {
-    auto rank = omp_get_thread_num();
-    run(rank);
-  }
+  common::ParallelFor(world_size, world_size, run);
 #else
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < world_size; rank++) {

@@ -13,6 +13,7 @@

 #include "../../../plugin/federated/federated_server.h"
 #include "../../../src/collective/communicator-inl.h"
+#include "../../../src/common/threading_utils.h"

 namespace xgboost {


@@ -75,11 +76,7 @@ void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_
     xgboost::collective::Finalize();
   };
 #if defined(_OPENMP)
-#pragma omp parallel num_threads(world_size)
-  {
-    auto rank = omp_get_thread_num();
-    run(rank);
-  }
+  common::ParallelFor(world_size, world_size, run);
 #else
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < world_size; rank++) {

@@ -15,9 +15,9 @@

 namespace xgboost {
 namespace {
-auto MakeModel(std::string objective, std::shared_ptr<DMatrix> dmat) {
+auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<DMatrix> dmat) {
   std::unique_ptr<Learner> learner{Learner::Create({dmat})};
-  learner->SetParam("tree_method", "approx");
+  learner->SetParam("tree_method", tree_method);
   learner->SetParam("objective", objective);
   if (objective.find("quantile") != std::string::npos) {
     learner->SetParam("quantile_alpha", "0.5");

@@ -35,7 +35,7 @@ auto MakeModel(std::string objective, std::shared_ptr<DMatrix> dmat) {
 }

 void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
-                     std::string objective) {
+                     std::string tree_method, std::string objective) {
   auto const world_size = collective::GetWorldSize();
   auto const rank = collective::GetRank();
   std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};

@@ -61,7 +61,7 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
   }
   std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};

-  auto model = MakeModel(objective, sliced);
+  auto model = MakeModel(tree_method, objective, sliced);
   auto base_score = GetBaseScore(model);
   ASSERT_EQ(base_score, expected_base_score);
   ASSERT_EQ(model, expected_model);

@@ -76,7 +76,7 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
   void TearDown() override { server_.reset(nullptr); }

-  void Run(std::string objective) {
+  void Run(std::string tree_method, std::string objective) {
     static auto constexpr kRows{16};
     static auto constexpr kCols{16};


@@ -99,17 +99,22 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
       }
     }

-    auto model = MakeModel(objective, dmat);
+    auto model = MakeModel(tree_method, objective, dmat);
     auto score = GetBaseScore(model);

     RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
-                                 score, model, objective);
+                                 score, model, tree_method, objective);
   }
 };

-TEST_P(FederatedLearnerTest, Objective) {
+TEST_P(FederatedLearnerTest, Approx) {
   std::string objective = GetParam();
-  this->Run(objective);
+  this->Run("approx", objective);
+}
+
+TEST_P(FederatedLearnerTest, Hist) {
+  std::string objective = GetParam();
+  this->Run("hist", objective);
 }

 INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest,
@@ -33,7 +33,7 @@ void TestEvaluateSplits(bool force_read_by_column) {

   auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();

-  auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, &param, dmat->Info(), sampler};
+  auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
   common::HistCollection hist;
   std::vector<GradientPair> row_gpairs = {
       {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},

@@ -167,7 +167,7 @@ TEST(HistEvaluator, Apply) {
   param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
   auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, &param, dmat->Info(), sampler};
+  auto evaluator_ = HistEvaluator{&ctx, &param, dmat->Info(), sampler};

   CPUExpandEntry entry{0, 0};
   entry.split.loss_chg = 10.0f;

@@ -195,7 +195,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
   // check the evaluator is returning the optimal split
   std::vector<FeatureType> ft{FeatureType::kCategorical};
   auto sampler = std::make_shared<common::ColumnSampler>();
-  HistEvaluator<CPUExpandEntry> evaluator{&ctx, &param_, info_, sampler};
+  HistEvaluator evaluator{&ctx, &param_, info_, sampler};
   evaluator.InitRoot(GradStats{total_gpair_});
   RegTree tree;
   std::vector<CPUExpandEntry> entries(1);

@@ -225,7 +225,7 @@ auto CompareOneHotAndPartition(bool onehot) {
       RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();

   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, &param, dmat->Info(), sampler};
+  auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
   std::vector<CPUExpandEntry> entries(1);

   for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>(&ctx, {32, param.sparse_threshold})) {

@@ -276,7 +276,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
   info.num_col_ = 1;
   info.feature_types = {FeatureType::kCategorical};
   Context ctx;
-  auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, &param_, info, sampler};
+  auto evaluator = HistEvaluator{&ctx, &param_, info, sampler};
   evaluator.InitRoot(GradStats{parent_sum_});

   std::vector<CPUExpandEntry> entries(1);

@@ -79,7 +79,7 @@ TEST(CPUMonoConstraint, Basic) {
   auto Xy = RandomDataGenerator{kRows, kCols, 0.0}.GenerateDMatrix(true);
   auto sampler = std::make_shared<common::ColumnSampler>();

-  HistEvaluator<CPUExpandEntry> evalutor{&ctx, &param, Xy->Info(), sampler};
+  HistEvaluator evalutor{&ctx, &param, Xy->Info(), sampler};
   evalutor.InitRoot(GradStats{2.0, 2.0});

   SplitEntry split;

@@ -9,28 +9,20 @@
 #include "../helpers.h"

 namespace xgboost::tree {
-std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols){
-  return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
-}
-
-std::unique_ptr<HostDeviceVector<GradientPair>> GenerateGradients(std::size_t rows) {
-  auto p_gradients = std::make_unique<HostDeviceVector<GradientPair>>(rows);
-  auto& h_gradients = p_gradients->HostVector();
-
-  xgboost::SimpleLCG gen;
-  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
-
-  for (std::size_t i = 0; i < rows; ++i) {
-    auto grad = dist(&gen);
-    auto hess = dist(&gen);
-    h_gradients[i] = GradientPair{grad, hess};
-  }
-
-  return p_gradients;
-}
-
-TEST(GrowHistMaker, InteractionConstraint)
-{
+std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols,
+                                         bool categorical = false) {
+  if (categorical) {
+    std::vector<FeatureType> ft(cols);
+    for (size_t i = 0; i < ft.size(); ++i) {
+      ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
+    }
+    return RandomDataGenerator(rows, cols, 0.6f).Seed(3).Type(ft).MaxCategory(17).GenerateDMatrix();
+  } else {
+    return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
+  }
+}
+
+TEST(GrowHistMaker, InteractionConstraint) {
   auto constexpr kRows = 32;
   auto constexpr kCols = 16;
   auto p_dmat = GenerateDMatrix(kRows, kCols);

@@ -74,8 +66,9 @@ TEST(GrowHistMaker, InteractionConstraint)
 }

 namespace {
-void TestColumnSplit(int32_t rows, bst_feature_t cols, RegTree const& expected_tree) {
-  auto p_dmat = GenerateDMatrix(rows, cols);
+void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical,
+                       RegTree const& expected_tree) {
+  auto p_dmat = GenerateDMatrix(rows, cols, categorical);
   auto p_gradients = GenerateGradients(rows);
   Context ctx;
   ObjInfo task{ObjInfo::kRegression};

@@ -90,27 +83,21 @@ void TestColumnSplit(int32_t rows, bst_feature_t cols, RegTree const& expected_t
   param.Init(Args{});
   updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});

-  ASSERT_EQ(tree.NumExtraNodes(), 10);
-  ASSERT_EQ(tree[0].SplitIndex(), 1);
-
-  ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
-  ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
-
-  FeatureMap fmap;
-  auto json = tree.DumpModel(fmap, false, "json");
-  auto expected_json = expected_tree.DumpModel(fmap, false, "json");
+  Json json{Object{}};
+  tree.SaveModel(&json);
+  Json expected_json{Object{}};
+  expected_tree.SaveModel(&expected_json);
   ASSERT_EQ(json, expected_json);
 }
-}  // anonymous namespace

-TEST(GrowHistMaker, ColumnSplit) {
+void TestColumnSplit(bool categorical) {
   auto constexpr kRows = 32;
   auto constexpr kCols = 16;

   RegTree expected_tree{1u, kCols};
   ObjInfo task{ObjInfo::kRegression};
   {
-    auto p_dmat = GenerateDMatrix(kRows, kCols);
+    auto p_dmat = GenerateDMatrix(kRows, kCols, categorical);
     auto p_gradients = GenerateGradients(kRows);
     Context ctx;
     std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};

@@ -121,6 +108,12 @@ TEST(GrowHistMaker, ColumnSplit) {
   }

   auto constexpr kWorldSize = 2;
-  RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit, kRows, kCols, std::cref(expected_tree));
+  RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, categorical,
+                              std::cref(expected_tree));
 }
+}  // anonymous namespace
+
+TEST(GrowHistMaker, ColumnSplitNumerical) { TestColumnSplit(false); }
+
+TEST(GrowHistMaker, ColumnSplitCategorical) { TestColumnSplit(true); }
 }  // namespace xgboost::tree

@@ -194,11 +194,65 @@ void TestColumnSplitPartitioner(bst_target_t n_targets) {

   auto constexpr kWorkers = 4;
   RunWithInMemoryCommunicator(kWorkers, VerifyColumnSplitPartitioner<ExpandEntry>, n_targets,
-      n_samples, n_features, base_rowid, Xy, min_value, mid_value, mid_partitioner);
+                              n_samples, n_features, base_rowid, Xy, min_value, mid_value,
+                              mid_partitioner);
 }
 }  // anonymous namespace

 TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner<CPUExpandEntry>(1); }

 TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); }
+
+namespace {
+void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
+                       RegTree const& expected_tree) {
+  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
+  auto p_gradients = GenerateGradients(rows, n_targets);
+  Context ctx;
+  ObjInfo task{ObjInfo::kRegression};
+  std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+
+  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+
+  RegTree tree{n_targets, cols};
+  TrainParam param;
+  param.Init(Args{});
+  updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
+
+  Json json{Object{}};
+  tree.SaveModel(&json);
+  Json expected_json{Object{}};
+  expected_tree.SaveModel(&expected_json);
+  ASSERT_EQ(json, expected_json);
+}
+
+void TestColumnSplit(bst_target_t n_targets) {
+  auto constexpr kRows = 32;
+  auto constexpr kCols = 16;
+
+  RegTree expected_tree{n_targets, kCols};
+  ObjInfo task{ObjInfo::kRegression};
+  {
+    auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
+    auto p_gradients = GenerateGradients(kRows, n_targets);
+    Context ctx;
+    std::unique_ptr<TreeUpdater> updater{
+        TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    TrainParam param;
+    param.Init(Args{});
+    updater->Update(&param, p_gradients.get(), Xy.get(), position, {&expected_tree});
+  }
+
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets,
+                              std::cref(expected_tree));
+}
+}  // anonymous namespace
+
+TEST(QuantileHist, ColumnSplit) { TestColumnSplit(1); }
+
+TEST(QuantileHist, ColumnSplitMultiTarget) { TestColumnSplit(3); }
+
 }  // namespace xgboost::tree