Make sure metrics work with column-wise distributed training (#9020)
This commit is contained in:
@@ -116,8 +116,10 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
|
||||
|
||||
// we have 2 averages going in here, first is among workers, second is among
|
||||
// classes. allreduce sums up fp/tp auc for each class.
|
||||
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
|
||||
results.Values().size());
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
|
||||
results.Values().size());
|
||||
}
|
||||
double auc_sum{0};
|
||||
double tp_sum{0};
|
||||
for (size_t c = 0; c < n_classes; ++c) {
|
||||
@@ -290,7 +292,9 @@ class EvalAUC : public MetricNoCache {
|
||||
}
|
||||
|
||||
std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
|
||||
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
|
||||
}
|
||||
auc = results[0];
|
||||
valid_groups = static_cast<uint32_t>(results[1]);
|
||||
|
||||
@@ -319,7 +323,9 @@ class EvalAUC : public MetricNoCache {
|
||||
}
|
||||
double local_area = fp * tp;
|
||||
std::array<double, 2> result{auc, local_area};
|
||||
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
|
||||
}
|
||||
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
|
||||
if (local_area <= 0) {
|
||||
// the dataset across all workers have only positive or negative sample
|
||||
|
||||
@@ -198,7 +198,7 @@ class PseudoErrorLoss : public MetricNoCache {
|
||||
return std::make_tuple(v, wt);
|
||||
});
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
if (collective::IsDistributed()) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
|
||||
@@ -367,7 +367,9 @@ struct EvalEWiseBase : public MetricNoCache {
|
||||
});
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
@@ -439,7 +441,9 @@ class QuantileError : public MetricNoCache {
|
||||
if (info.num_row_ == 0) {
|
||||
// empty DMatrix on distributed env
|
||||
double dat[2]{0.0, 0.0};
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
CHECK_GT(dat[1], 0);
|
||||
return dat[0] / dat[1];
|
||||
}
|
||||
@@ -477,7 +481,9 @@ class QuantileError : public MetricNoCache {
|
||||
return std::make_tuple(l, w);
|
||||
});
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
CHECK_GT(dat[1], 0);
|
||||
return dat[0] / dat[1];
|
||||
}
|
||||
|
||||
@@ -181,7 +181,9 @@ struct EvalMClassBase : public MetricNoCache {
|
||||
dat[0] = result.Residue();
|
||||
dat[1] = result.Weights();
|
||||
}
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
|
||||
@@ -244,7 +244,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
exc.Rethrow();
|
||||
}
|
||||
|
||||
if (collective::IsDistributed()) {
|
||||
if (collective::IsDistributed() && info.IsRowSplit()) {
|
||||
double dat[2]{sum_metric, static_cast<double>(ngroups)};
|
||||
// approximately estimate the metric using mean
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
@@ -401,9 +401,11 @@ class EvalRankWithCache : public Metric {
|
||||
};
|
||||
|
||||
namespace {
|
||||
double Finalize(double score, double sw) {
|
||||
double Finalize(MetaInfo const& info, double score, double sw) {
|
||||
std::array<double, 2> dat{score, sw};
|
||||
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
|
||||
}
|
||||
if (sw > 0.0) {
|
||||
score = score / sw;
|
||||
}
|
||||
@@ -430,7 +432,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache) override {
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
|
||||
return Finalize(ndcg.Residue(), ndcg.Weights());
|
||||
return Finalize(info, ndcg.Residue(), ndcg.Weights());
|
||||
}
|
||||
|
||||
// group local ndcg
|
||||
@@ -476,7 +478,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
|
||||
}
|
||||
auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
|
||||
return Finalize(ndcg, sum_w);
|
||||
return Finalize(info, ndcg, sum_w);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -489,7 +491,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
std::shared_ptr<ltr::MAPCache> p_cache) override {
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
|
||||
return Finalize(map.Residue(), map.Weights());
|
||||
return Finalize(info, map.Residue(), map.Weights());
|
||||
}
|
||||
|
||||
auto gptr = p_cache->DataGroupPtr(ctx_);
|
||||
@@ -532,7 +534,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
sw += weight[i];
|
||||
}
|
||||
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
|
||||
return Finalize(sum, sw);
|
||||
return Finalize(info, sum, sw);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -212,7 +212,9 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
|
||||
info.labels_upper_bound_, preds);
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||
}
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user