More collective aggregators (#9060)

This commit is contained in:
Rong Ou
2023-04-21 12:32:05 -07:00
committed by GitHub
parent 7032981350
commit 8dbe0510de
11 changed files with 107 additions and 89 deletions

View File

@@ -7,6 +7,7 @@
#include <dmlc/registry.h>
#include <array>
#include <memory>
#include <vector>
@@ -211,10 +212,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
info.labels_upper_bound_, preds);
double dat[2]{result.Residue(), result.Weights()};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
return Policy::GetFinal(dat[0], dat[1]);
}