Support column split in multi-target hist (#9171)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
@@ -224,5 +225,46 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AllgatherVResult {
|
||||
std::vector<std::size_t> offsets;
|
||||
std::vector<std::size_t> sizes;
|
||||
std::vector<T> result;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* We assume each worker has the same number of inputs, but each input may be of a different size.
|
||||
*
|
||||
* @param inputs All the inputs from the local worker.
|
||||
* @param sizes Sizes of each input.
|
||||
*/
|
||||
template <typename T>
|
||||
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
|
||||
std::vector<std::size_t> const &sizes) {
|
||||
auto num_inputs = sizes.size();
|
||||
|
||||
// Gather the sizes across all workers.
|
||||
std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
|
||||
std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
|
||||
collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
|
||||
|
||||
// Calculate input offsets (std::exclusive_scan).
|
||||
std::vector<std::size_t> offsets(all_sizes.size());
|
||||
for (auto i = 1; i < offsets.size(); i++) {
|
||||
offsets[i] = offsets[i - 1] + all_sizes[i - 1];
|
||||
}
|
||||
|
||||
// Gather all the inputs.
|
||||
auto total_input_size = offsets.back() + all_sizes.back();
|
||||
std::vector<T> all_inputs(total_input_size);
|
||||
std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
|
||||
// We cannot use allgather here, since each worker might have a different size.
|
||||
Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
|
||||
|
||||
return {offsets, all_sizes, all_inputs};
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user