Improve allgather functions (#9649)

This commit is contained in:
Rong Ou
2023-10-12 08:31:43 -07:00
committed by GitHub
parent d1dee4ad99
commit e164d51c43
20 changed files with 346 additions and 122 deletions

View File

@@ -76,10 +76,8 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul);
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {
return;
}