Use UBJSON for serializing splits for vertical data split. (#10059)

This commit is contained in:
Jiaming Yuan 2024-02-25 00:18:23 +08:00 committed by GitHub
parent f7005d32c1
commit 0ce4372bd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 162 additions and 165 deletions

View File

@ -174,7 +174,7 @@ jobs:
- uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0
with: with:
submodules: 'true' submodules: 'true'
- uses: actions/setup-python@7f80679172b057fc5e90d70d197929d454754a5a # v4.3.0 - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
with: with:
python-version: "3.8" python-version: "3.8"
architecture: 'x64' architecture: 'x64'

View File

@ -310,7 +310,7 @@ jobs:
submodules: 'true' submodules: 'true'
- name: Set up Python 3.8 - name: Set up Python 3.8
uses: actions/setup-python@v4 uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
with: with:
python-version: 3.8 python-version: 3.8

View File

@ -21,7 +21,7 @@ jobs:
with: with:
submodules: 'true' submodules: 'true'
- name: Setup Python - name: Setup Python
uses: actions/setup-python@7f80679172b057fc5e90d70d197929d454754a5a # v4.3.0 uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
with: with:
python-version: "3.8" python-version: "3.8"
- name: Build wheels - name: Build wheels

View File

@ -74,7 +74,7 @@ jobs:
key: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} key: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }}
restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }}
- uses: actions/setup-python@7f80679172b057fc5e90d70d197929d454754a5a # v4.3.0 - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
with: with:
python-version: "3.8" python-version: "3.8"
architecture: 'x64' architecture: 'x64'

View File

@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \

View File

@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \

View File

@ -0,0 +1,34 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "communicator-inl.h"
namespace xgboost::collective {
/**
 * @brief Allgather a list of byte buffers whose count and lengths may differ per worker.
 *
 * The lengths of the local buffers are exchanged first, then the local buffers are
 * flattened into one contiguous array and exchanged, and finally the flat result is cut
 * back into individual buffers using the gathered lengths.
 *
 * @param input Byte buffers owned by the local worker.
 *
 * @return One buffer per gathered length, in the order the lengths were returned by
 *         AllgatherV.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    std::vector<std::vector<char>> const &input) {
  // Step 1: exchange the length of every buffer.
  std::vector<std::int64_t> local_lengths(input.size());
  std::transform(input.cbegin(), input.cend(), local_lengths.begin(),
                 [](auto const &buf) { return buf.size(); });
  std::vector<std::int64_t> const all_lengths = AllgatherV(local_lengths);

  // Step 2: concatenate the local buffers and exchange the raw bytes.
  std::vector<char> flattened;
  for (auto const &buf : input) {
    flattened.insert(flattened.end(), buf.cbegin(), buf.cend());
  }
  auto const gathered = AllgatherV(flattened);

  // Step 3: slice the gathered bytes; `begin` is the running prefix sum of the lengths.
  std::vector<std::vector<char>> result;
  std::int64_t begin = 0;
  for (auto const length : all_lengths) {
    auto const end = begin + length;
    result.emplace_back(gathered.cbegin() + begin, gathered.cbegin() + end);
    begin = end;
  }
  return result;
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2022-2023 by XGBoost contributors * Copyright 2022-2024, XGBoost contributors
*/ */
#pragma once #pragma once
#include <string> #include <string>
@ -192,6 +192,18 @@ inline std::vector<T> AllgatherV(std::vector<T> const &input) {
return result; return result;
} }
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param input All the inputs from the local worker. The number of inputs can vary
* across different workers, and the size of each individual vector in
* the input can vary as well.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input);
/** /**
* @brief Gathers variable-length strings from all processes and distributes them to all processes. * @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings. * @param input Variable-length list of variable-length strings.
@ -294,38 +306,5 @@ template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) { inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op); Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
} }
/**
* @brief Return value of SpecialAllgatherV: the gathered elements plus the bookkeeping
* needed to locate each individual input inside them.
*
* @tparam T Element type of the gathered data.
*/
template <typename T>
struct SpecialAllgatherVResult {
// Start position of each input inside `result` (exclusive prefix sum of `sizes`).
std::vector<std::size_t> offsets;
// Size of each input, as gathered from all workers.
std::vector<std::size_t> sizes;
// The gathered elements from all workers.
std::vector<T> result;
};
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* We assume each worker has the same number of inputs, but each input may be of a different size.
*
* @tparam T Element type of the inputs.
*
* @param inputs All the inputs from the local worker (presumably stored back to back in
* one flat vector, with `sizes` describing where each input ends -- confirm
* against the callers).
* @param sizes Sizes of each input.
*
* @return The gathered elements together with the gathered sizes and the offset of each
* input inside the gathered elements.
*/
template <typename T>
inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs,
std::vector<std::size_t> const &sizes) {
// Gather the sizes across all workers.
auto const all_sizes = Allgather(sizes);
// Calculate input offsets (std::exclusive_scan).
// The vector is value-initialized, so offsets[0] is already 0 and the loop starts at 1.
std::vector<std::size_t> offsets(all_sizes.size());
for (std::size_t i = 1; i < offsets.size(); i++) {
offsets[i] = offsets[i - 1] + all_sizes[i - 1];
}
// Gather all the inputs.
auto const all_inputs = AllgatherV(inputs);
return {offsets, all_sizes, all_inputs};
}
} // namespace collective } // namespace collective
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021-2023 by XGBoost Contributors * Copyright 2021-2024, XGBoost Contributors
*/ */
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
@ -26,6 +26,47 @@
#include "xgboost/linalg.h" // for Constants, Vector #include "xgboost/linalg.h" // for Constants, Vector
namespace xgboost::tree { namespace xgboost::tree {
/**
* @brief Gather the expand entries from all the workers.
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
template <typename ExpandEntry>
std::enable_if_t<std::is_same_v<ExpandEntry, CPUExpandEntry> ||
std::is_same_v<ExpandEntry, MultiExpandEntry>,
std::vector<ExpandEntry>>
AllgatherColumnSplit(std::vector<ExpandEntry> const &entries) {
auto const n_entries = entries.size();
// First, gather all the primitive fields.
std::vector<ExpandEntry> local_entries(n_entries);
// Collect and serialize all entries
std::vector<std::vector<char>> serialized_entries;
for (std::size_t i = 0; i < n_entries; ++i) {
Json jentry{Object{}};
entries[i].Save(&jentry);
std::vector<char> out;
Json::Dump(jentry, &out, std::ios::binary);
serialized_entries.emplace_back(std::move(out));
}
auto all_serialized = collective::VectorAllgatherV(serialized_entries);
CHECK_GE(all_serialized.size(), local_entries.size());
std::vector<ExpandEntry> all_entries(all_serialized.size());
std::transform(all_serialized.cbegin(), all_serialized.cend(), all_entries.begin(),
[](std::vector<char> const &e) {
ExpandEntry entry;
auto je = Json::Load(StringView{e.data(), e.size()}, std::ios::binary);
entry.Load(je);
return entry;
});
return all_entries;
}
class HistEvaluator { class HistEvaluator {
private: private:
struct NodeEntry { struct NodeEntry {
@ -36,8 +77,8 @@ class HistEvaluator {
}; };
private: private:
Context const* ctx_; Context const *ctx_;
TrainParam const* param_; TrainParam const *param_;
std::shared_ptr<common::ColumnSampler> column_sampler_; std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_; TreeEvaluator tree_evaluator_;
bool is_col_split_{false}; bool is_col_split_{false};
@ -202,7 +243,7 @@ class HistEvaluator {
common::CatBitField cat_bits{best.cat_bits}; common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin); bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0); CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) { std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](std::size_t c) {
auto cat = cut_val[c + f_begin]; auto cat = cut_val[c + f_begin];
cat_bits.Set(cat); cat_bits.Set(cat);
}); });
@ -285,57 +326,23 @@ class HistEvaluator {
return left_sum; return left_sum;
} }
/**
* @brief Gather the expand entries from all the workers.
*
* The entries themselves are exchanged with collective::Allgather, while the
* variable-length categorical bits are collected into one flat buffer, exchanged
* separately with collective::SpecialAllgatherV, and then copied back into the
* gathered entries.
*
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const num_entries = entries.size();
// First, gather all the primitive fields.
// CopyAndCollect also receives the flat `cat_bits` buffer and the per-entry
// `cat_bits_sizes` list that are exchanged separately below.
std::vector<CPUExpandEntry> local_entries(num_entries);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
for (std::size_t i = 0; i < num_entries; i++) {
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
}
auto all_entries = collective::Allgather(local_entries);
// Gather all the cat_bits.
auto gathered = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
// NOTE(review): the loop bound num_entries * world presumes every worker contributes
// exactly num_entries entries -- confirm this holds for all callers.
common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
all_entries[i].split.cat_bits.begin());
});
return all_entries;
}
public: public:
void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut, void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
common::Span<FeatureType const> feature_types, const RegTree &tree, common::Span<FeatureType const> feature_types, const RegTree &tree,
std::vector<CPUExpandEntry> *p_entries) { std::vector<CPUExpandEntry> *p_entries) {
auto n_threads = ctx_->Threads(); auto n_threads = ctx_->Threads();
auto& entries = *p_entries; auto &entries = *p_entries;
// All nodes are on the same level, so we can store the shared ptr. // All nodes are on the same level, so we can store the shared ptr.
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features( std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
entries.size());
for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
auto nidx = entries[nidx_in_set].nid; auto nidx = entries[nidx_in_set].nid;
features[nidx_in_set] = features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
} }
CHECK(!features.empty()); CHECK(!features.empty());
const size_t grain_size = const size_t grain_size = std::max<size_t>(1, features.front()->Size() / n_threads);
std::max<size_t>(1, features.front()->Size() / n_threads); common::BlockedSpace2d space(
common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) { entries.size(), [&](size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
return features[nidx_in_set]->Size(); grain_size);
}, grain_size);
std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size()); std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
for (size_t i = 0; i < entries.size(); ++i) { for (size_t i = 0; i < entries.size(); ++i) {
@ -344,7 +351,7 @@ class HistEvaluator {
} }
} }
auto evaluator = tree_evaluator_.GetEvaluator(); auto evaluator = tree_evaluator_.GetEvaluator();
auto const& cut_ptrs = cut.Ptrs(); auto const &cut_ptrs = cut.Ptrs();
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) { common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num(); auto tidx = omp_get_thread_num();
@ -385,18 +392,16 @@ class HistEvaluator {
} }
}); });
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) { for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update( entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
tloc_candidates[n_threads * nidx_in_set + tidx].split);
} }
} }
if (is_col_split_) { if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the // With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly. // expand entries accordingly.
auto all_entries = Allgather(entries); auto all_entries = AllgatherColumnSplit(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update( entries[nidx_in_set].split.Update(
@ -407,7 +412,7 @@ class HistEvaluator {
} }
// Add splits to tree, handles all statistic // Add splits to tree, handles all statistic
void ApplyTreeSplit(CPUExpandEntry const& candidate, RegTree *p_tree) { void ApplyTreeSplit(CPUExpandEntry const &candidate, RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator(); auto evaluator = tree_evaluator_.GetEvaluator();
RegTree &tree = *p_tree; RegTree &tree = *p_tree;
@ -437,8 +442,7 @@ class HistEvaluator {
auto left_child = tree[candidate.nid].LeftChild(); auto left_child = tree[candidate.nid].LeftChild();
auto right_child = tree[candidate.nid].RightChild(); auto right_child = tree[candidate.nid].RightChild();
tree_evaluator_.AddSplit(candidate.nid, left_child, right_child, tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
tree[candidate.nid].SplitIndex(), left_weight, tree[candidate.nid].SplitIndex(), left_weight, right_weight);
right_weight);
evaluator = tree_evaluator_.GetEvaluator(); evaluator = tree_evaluator_.GetEvaluator();
snode_.resize(tree.GetNodes().size()); snode_.resize(tree.GetNodes().size());
@ -449,8 +453,7 @@ class HistEvaluator {
snode_.at(right_child).root_gain = snode_.at(right_child).root_gain =
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum}); evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});
interaction_constraints_.Split(candidate.nid, interaction_constraints_.Split(candidate.nid, tree[candidate.nid].SplitIndex(), left_child,
tree[candidate.nid].SplitIndex(), left_child,
right_child); right_child);
} }
@ -571,53 +574,6 @@ class HistMultiEvaluator {
return false; return false;
} }
/**
* @brief Gather the expand entries from all the workers.
*
* The entries themselves are exchanged with collective::Allgather. The categorical
* bits and the gradient sums are collected into flat buffers, exchanged separately,
* and then copied back into the gathered entries.
*
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const num_entries = entries.size();
// First, gather all the primitive fields.
// CopyAndCollect also receives the flat `cat_bits`, `cat_bits_sizes` and `gradients`
// buffers that are exchanged separately below.
std::vector<MultiExpandEntry> local_entries(num_entries);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
std::vector<GradientPairPrecise> gradients;
for (std::size_t i = 0; i < num_entries; i++) {
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, &gradients);
}
auto all_entries = collective::Allgather(local_entries);
// Gather all the cat_bits.
auto gathered_cat_bits = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
// Gather all the gradients.
auto const num_gradients = gradients.size();
auto const all_gradients = collective::Allgather(gradients);
// NOTE(review): total_entries presumes every worker contributes exactly num_entries
// entries, and the division below is by num_entries, so an empty `entries` would
// divide by zero -- confirm callers never pass an empty set.
auto const total_entries = num_entries * world;
auto const gradients_per_entry = num_gradients / num_entries;
// Each entry's gradients are split evenly: first half is left_sum, second half is right_sum.
auto const gradients_per_side = gradients_per_entry / 2;
common::ParallelFor(total_entries, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());
// Copy the gradients back into all expand entries.
all_entries[i].split.left_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
all_entries[i].split.left_sum.begin());
all_entries[i].split.right_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
gradients_per_side, all_entries[i].split.right_sum.begin());
});
return all_entries;
}
public: public:
void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist, void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) { common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
@ -676,7 +632,7 @@ class HistMultiEvaluator {
if (is_col_split_) { if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the // With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly. // expand entries accordingly.
auto all_entries = Allgather(entries); auto all_entries = AllgatherColumnSplit(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update( entries[nidx_in_set].split.Update(

View File

@ -90,7 +90,6 @@ struct ExpandEntryImpl {
} }
self->split.is_cat = get<Boolean const>(split["is_cat"]); self->split.is_cat = get<Boolean const>(split["is_cat"]);
self->LoadGrad(split); self->LoadGrad(split);
} }
}; };
@ -106,8 +105,8 @@ struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
void SaveGrad(Json* p_out) const { void SaveGrad(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
auto save = [&](std::string const& name, GradStats const& sum) { auto save = [&](std::string const& name, GradStats const& sum) {
out[name] = F32Array{2}; out[name] = F64Array{2};
auto& array = get<F32Array>(out[name]); auto& array = get<F64Array>(out[name]);
array[0] = sum.GetGrad(); array[0] = sum.GetGrad();
array[1] = sum.GetHess(); array[1] = sum.GetHess();
}; };
@ -115,9 +114,9 @@ struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
save("right_sum", this->split.right_sum); save("right_sum", this->split.right_sum);
} }
void LoadGrad(Json const& in) { void LoadGrad(Json const& in) {
auto const& left_sum = get<F32Array const>(in["left_sum"]); auto const& left_sum = get<F64Array const>(in["left_sum"]);
this->split.left_sum = GradStats{left_sum[0], left_sum[1]}; this->split.left_sum = GradStats{left_sum[0], left_sum[1]};
auto const& right_sum = get<F32Array const>(in["right_sum"]); auto const& right_sum = get<F64Array const>(in["right_sum"]);
this->split.right_sum = GradStats{right_sum[0], right_sum[1]}; this->split.right_sum = GradStats{right_sum[0], right_sum[1]};
} }
@ -173,8 +172,8 @@ struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
void SaveGrad(Json* p_out) const { void SaveGrad(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
auto save = [&](std::string const& name, std::vector<GradientPairPrecise> const& sum) { auto save = [&](std::string const& name, std::vector<GradientPairPrecise> const& sum) {
out[name] = F32Array{sum.size() * 2}; out[name] = F64Array{sum.size() * 2};
auto& array = get<F32Array>(out[name]); auto& array = get<F64Array>(out[name]);
for (std::size_t i = 0, j = 0; i < sum.size(); i++, j += 2) { for (std::size_t i = 0, j = 0; i < sum.size(); i++, j += 2) {
array[j] = sum[i].GetGrad(); array[j] = sum[i].GetGrad();
array[j + 1] = sum[i].GetHess(); array[j + 1] = sum[i].GetHess();
@ -185,7 +184,7 @@ struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
} }
void LoadGrad(Json const& in) { void LoadGrad(Json const& in) {
auto load = [&](std::string const& name, std::vector<GradientPairPrecise>* p_sum) { auto load = [&](std::string const& name, std::vector<GradientPairPrecise>* p_sum) {
auto const& array = get<F32Array const>(in[name]); auto const& array = get<F64Array const>(in[name]);
auto& sum = *p_sum; auto& sum = *p_sum;
sum.resize(array.size() / 2); sum.resize(array.size() / 2);
for (std::size_t i = 0, j = 0; i < sum.size(); ++i, j += 2) { for (std::size_t i = 0, j = 0; i < sum.size(); ++i, j += 2) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2017-2023, XGBoost Contributors * Copyright 2017-2024, XGBoost Contributors
* \file updater_quantile_hist.cc * \file updater_quantile_hist.cc
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Chen, Egor Smirnov * \author Philip Cho, Tianqi Chen, Egor Smirnov
@ -149,9 +149,6 @@ class MultiTargetHistBuilder {
} }
void InitData(DMatrix *p_fmat, RegTree const *p_tree) { void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
if (collective::IsDistributed()) {
LOG(FATAL) << "Distributed training for vector-leaf is not yet supported.";
}
monitor_->Start(__func__); monitor_->Start(__func__);
p_last_fmat_ = p_fmat; p_last_fmat_ = p_fmat;

View File

@ -1,13 +1,12 @@
/*! /**
* Copyright 2022 XGBoost contributors * Copyright 2022-2024, XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../../src/collective/rabit_communicator.h" #include "../../../src/collective/rabit_communicator.h"
#include "../helpers.h"
namespace xgboost { namespace xgboost::collective {
namespace collective {
TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
auto construct = []() { RabitCommunicator comm{0, 0}; }; auto construct = []() { RabitCommunicator comm{0, 0}; };
EXPECT_THROW(construct(), dmlc::Error); EXPECT_THROW(construct(), dmlc::Error);
@ -35,5 +34,37 @@ TEST(RabitCommunicatorSimpleTest, IsNotDistributed) {
EXPECT_FALSE(comm.IsDistributed()); EXPECT_FALSE(comm.IsDistributed());
} }
} // namespace collective namespace {
} // namespace xgboost void VerifyVectorAllgatherV() {
auto n_workers = collective::GetWorldSize();
ASSERT_EQ(n_workers, 3);
auto rank = collective::GetRank();
// Construct input that has different length for each worker.
std::vector<std::vector<char>> inputs;
for (std::int32_t i = 0; i < rank + 1; ++i) {
std::vector<char> in;
for (std::int32_t j = 0; j < rank + 1; ++j) {
in.push_back(static_cast<char>(j));
}
inputs.emplace_back(std::move(in));
}
auto outputs = VectorAllgatherV(inputs);
ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2);
auto const& res = outputs;
for (std::int32_t i = 0; i < n_workers; ++i) {
std::int32_t k = 0;
for (auto v : res[i]) {
ASSERT_EQ(v, k++);
}
}
}
} // namespace
// Runs VerifyVectorAllgatherV on every worker of a three-worker in-memory communicator.
TEST(VectorAllgatherV, Basic) {
  std::int32_t world_size{3};
  RunWithInMemoryCommunicator(world_size, VerifyVectorAllgatherV);
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2023, XGBoost Contributors * Copyright 2019-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2018-2023 by XGBoost Contributors * Copyright 2018-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
@ -18,7 +18,6 @@
#include "xgboost/data.h" #include "xgboost/data.h"
namespace xgboost::tree { namespace xgboost::tree {
namespace { namespace {
template <typename ExpandEntry> template <typename ExpandEntry>
void TestPartitioner(bst_target_t n_targets) { void TestPartitioner(bst_target_t n_targets) {
@ -253,5 +252,5 @@ void TestColumnSplit(bst_target_t n_targets) {
TEST(QuantileHist, ColumnSplit) { TestColumnSplit(1); } TEST(QuantileHist, ColumnSplit) { TestColumnSplit(1); }
TEST(QuantileHist, DISABLED_ColumnSplitMultiTarget) { TestColumnSplit(3); } TEST(QuantileHist, ColumnSplitMultiTarget) { TestColumnSplit(3); }
} // namespace xgboost::tree } // namespace xgboost::tree