Improve allgather functions (#9649)
This commit is contained in:
@@ -292,20 +292,19 @@ class HistEvaluator {
|
||||
*/
|
||||
std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const num_entries = entries.size();
|
||||
|
||||
// First, gather all the primitive fields.
|
||||
std::vector<CPUExpandEntry> all_entries(num_entries * world);
|
||||
std::vector<CPUExpandEntry> local_entries(num_entries);
|
||||
std::vector<uint32_t> cat_bits;
|
||||
std::vector<std::size_t> cat_bits_sizes;
|
||||
for (std::size_t i = 0; i < num_entries; i++) {
|
||||
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
|
||||
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
|
||||
}
|
||||
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry));
|
||||
auto all_entries = collective::Allgather(local_entries);
|
||||
|
||||
// Gather all the cat_bits.
|
||||
auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes);
|
||||
auto gathered = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
|
||||
|
||||
common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
|
||||
// Copy the cat_bits back into all expand entries.
|
||||
@@ -579,28 +578,24 @@ class HistMultiEvaluator {
|
||||
*/
|
||||
std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const num_entries = entries.size();
|
||||
|
||||
// First, gather all the primitive fields.
|
||||
std::vector<MultiExpandEntry> all_entries(num_entries * world);
|
||||
std::vector<MultiExpandEntry> local_entries(num_entries);
|
||||
std::vector<uint32_t> cat_bits;
|
||||
std::vector<std::size_t> cat_bits_sizes;
|
||||
std::vector<GradientPairPrecise> gradients;
|
||||
for (std::size_t i = 0; i < num_entries; i++) {
|
||||
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
|
||||
&gradients);
|
||||
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, &gradients);
|
||||
}
|
||||
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry));
|
||||
auto all_entries = collective::Allgather(local_entries);
|
||||
|
||||
// Gather all the cat_bits.
|
||||
auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes);
|
||||
auto gathered_cat_bits = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
|
||||
|
||||
// Gather all the gradients.
|
||||
auto const num_gradients = gradients.size();
|
||||
std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
|
||||
std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
|
||||
collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise));
|
||||
auto const all_gradients = collective::Allgather(gradients);
|
||||
|
||||
auto const total_entries = num_entries * world;
|
||||
auto const gradients_per_entry = num_gradients / num_entries;
|
||||
|
||||
Reference in New Issue
Block a user