More collective aggregators (#9060)

This commit is contained in:
Rong Ou
2023-04-21 12:32:05 -07:00
committed by GitHub
parent 7032981350
commit 8dbe0510de
11 changed files with 107 additions and 89 deletions

View File

@@ -116,10 +116,7 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
// we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class.
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
results.Values().size());
}
collective::GlobalSum(info, &results.Values());
double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
@@ -293,17 +290,8 @@ class EvalAUC : public MetricNoCache {
InvalidGroupAUC();
}
std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
}
auc = results[0];
valid_groups = static_cast<uint32_t>(results[1]);
if (valid_groups <= 0) {
auc = std::numeric_limits<double>::quiet_NaN();
} else {
auc /= valid_groups;
auc = collective::GlobalRatio(info, auc, static_cast<double>(valid_groups));
if (!std::isnan(auc)) {
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups;
}
@@ -323,19 +311,9 @@ class EvalAUC : public MetricNoCache {
std::tie(fp, tp, auc) =
static_cast<Curve *>(this)->EvalBinary(preds, info);
}
double local_area = fp * tp;
std::array<double, 2> result{auc, local_area};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
}
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
if (local_area <= 0) {
// the dataset across all workers have only positive or negative sample
auc = std::numeric_limits<double>::quiet_NaN();
} else {
CHECK_LE(auc, local_area);
// normalization
auc = auc / local_area;
auc = collective::GlobalRatio(info, auc, fp * tp);
if (!std::isnan(auc)) {
CHECK_LE(auc, 1.0);
}
}
if (std::isnan(auc)) {

View File

@@ -8,6 +8,7 @@
*/
#include <dmlc/registry.h>
#include <array>
#include <cmath>
#include "../collective/communicator-inl.h"
@@ -197,10 +198,8 @@ class PseudoErrorLoss : public MetricNoCache {
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
return std::make_tuple(v, wt);
});
double dat[2]{result.Residue(), result.Weights()};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
}
};
@@ -366,10 +365,8 @@ struct EvalEWiseBase : public MetricNoCache {
return std::make_tuple(residue, wt);
});
double dat[2]{result.Residue(), result.Weights()};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
return Policy::GetFinal(dat[0], dat[1]);
}
@@ -440,10 +437,8 @@ class QuantileError : public MetricNoCache {
CHECK(!alpha_.Empty());
if (info.num_row_ == 0) {
// empty DMatrix on distributed env
double dat[2]{0.0, 0.0};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{0.0, 0.0};
collective::GlobalSum(info, &dat);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}
@@ -480,10 +475,8 @@ class QuantileError : public MetricNoCache {
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
return std::make_tuple(l, w);
});
double dat[2]{result.Residue(), result.Weights()};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}

View File

@@ -6,6 +6,7 @@
*/
#include <xgboost/metric.h>
#include <array>
#include <atomic>
#include <cmath>
@@ -169,7 +170,7 @@ struct EvalMClassBase : public MetricNoCache {
} else {
CHECK(preds.Size() % info.labels.Size() == 0) << "label and prediction size not match";
}
double dat[2] { 0.0, 0.0 };
std::array<double, 2> dat{0.0, 0.0};
if (info.labels.Size() != 0) {
const size_t nclass = preds.Size() / info.labels.Size();
CHECK_GE(nclass, 1U)
@@ -181,9 +182,7 @@ struct EvalMClassBase : public MetricNoCache {
dat[0] = result.Residue();
dat[1] = result.Weights();
}
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
collective::GlobalSum(info, &dat);
return Derived::GetFinal(dat[0], dat[1]);
}
/*!

View File

@@ -238,14 +238,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
exc.Rethrow();
}
if (collective::IsDistributed() && info.IsRowSplit()) {
double dat[2]{sum_metric, static_cast<double>(ngroups)};
// approximately estimate the metric using mean
collective::Allreduce<collective::Operation::kSum>(dat, 2);
return dat[0] / dat[1];
} else {
return sum_metric / ngroups;
}
return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
}
const char* Name() const override {
@@ -401,9 +394,8 @@ class EvalRankWithCache : public Metric {
namespace {
double Finalize(MetaInfo const& info, double score, double sw) {
std::array<double, 2> dat{score, sw};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
}
collective::GlobalSum(info, &dat);
std::tie(score, sw) = std::tuple_cat(dat);
if (sw > 0.0) {
score = score / sw;
}

View File

@@ -7,6 +7,7 @@
#include <dmlc/registry.h>
#include <array>
#include <memory>
#include <vector>
@@ -211,10 +212,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
info.labels_upper_bound_, preds);
double dat[2]{result.Residue(), result.Weights()};
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(dat, 2);
}
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
return Policy::GetFinal(dat[0], dat[1]);
}