Deterministic result for element-wise/mclass metrics. (#7303)

Remove openmp reduction.
This commit is contained in:
Jiaming Yuan 2021-10-13 14:22:40 +08:00 committed by GitHub
parent 406c70ba0e
commit 4ddf8d001c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 139 additions and 45 deletions

View File

@ -14,6 +14,7 @@
#include "metric_common.h"
#include "../common/math.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
@ -34,29 +35,29 @@ class ElementWiseMetricsReduction {
public:
explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {}
PackedReduceResult CpuReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) const {
PackedReduceResult
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
const HostDeviceVector<bst_float> &labels,
const HostDeviceVector<bst_float> &preds,
int32_t n_threads) const {
size_t ndata = labels.Size();
const auto& h_labels = labels.HostVector();
const auto& h_weights = weights.HostVector();
const auto& h_preds = preds.HostVector();
bst_float residue_sum = 0;
bst_float weights_sum = 0;
std::vector<double> score_tloc(n_threads, 0.0);
std::vector<double> weight_tloc(n_threads, 0.0);
dmlc::OMPException exc;
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) {
exc.Run([&]() {
const bst_float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
weights_sum += wt;
common::ParallelFor(ndata, n_threads, [&](size_t i) {
float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
auto t_idx = omp_get_thread_num();
score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
weight_tloc[t_idx] += wt;
});
}
exc.Rethrow();
double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
PackedReduceResult res { residue_sum, weights_sum };
return res;
}
@ -100,19 +101,19 @@ class ElementWiseMetricsReduction {
#endif // XGBOOST_USE_CUDA
PackedReduceResult Reduce(
const GenericParameter &tparam,
int device,
const GenericParameter &ctx,
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) {
PackedReduceResult result;
if (device < 0) {
result = CpuReduceMetrics(weights, labels, preds);
if (ctx.gpu_id < 0) {
auto n_threads = ctx.Threads();
result = CpuReduceMetrics(weights, labels, preds, n_threads);
}
#if defined(XGBOOST_USE_CUDA)
else { // NOLINT
device_ = device;
device_ = ctx.gpu_id;
preds.SetDevice(device_);
labels.SetDevice(device_);
weights.SetDevice(device_);
@ -365,10 +366,7 @@ struct EvalEWiseBase : public Metric {
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label and prediction size not match, "
<< "hint: use merror or mlogloss for multi-class classification";
int device = tparam_->gpu_id;
auto result =
reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds);
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_, preds);
double dat[2] { result.Residue(), result.Weights() };

View File

@ -6,11 +6,14 @@
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <atomic>
#include <cmath>
#include "metric_common.h"
#include "../common/math.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
@ -37,38 +40,41 @@ class MultiClassMetricsReduction {
public:
MultiClassMetricsReduction() = default;
PackedReduceResult CpuReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds,
const size_t n_class) const {
PackedReduceResult
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
const HostDeviceVector<bst_float> &labels,
const HostDeviceVector<bst_float> &preds,
const size_t n_class, int32_t n_threads) const {
size_t ndata = labels.Size();
const auto& h_labels = labels.HostVector();
const auto& h_weights = weights.HostVector();
const auto& h_preds = preds.HostVector();
bst_float residue_sum = 0;
bst_float weights_sum = 0;
int label_error = 0;
std::atomic<int> label_error {0};
bool const is_null_weight = weights.Size() == 0;
dmlc::OMPException exc;
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
for (omp_ulong idx = 0; idx < ndata; ++idx) {
exc.Run([&]() {
std::vector<double> scores_tloc(n_threads, 0);
std::vector<double> weights_tloc(n_threads, 0);
common::ParallelFor(ndata, n_threads, [&](size_t idx) {
bst_float weight = is_null_weight ? 1.0f : h_weights[idx];
auto label = static_cast<int>(h_labels[idx]);
if (label >= 0 && label < static_cast<int>(n_class)) {
residue_sum += EvalRowPolicy::EvalRow(
label, h_preds.data() + idx * n_class, n_class) * weight;
weights_sum += weight;
auto t_idx = omp_get_thread_num();
scores_tloc[t_idx] +=
EvalRowPolicy::EvalRow(label, h_preds.data() + idx * n_class,
n_class) *
weight;
weights_tloc[t_idx] += weight;
} else {
label_error = label;
}
});
}
exc.Rethrow();
double residue_sum =
std::accumulate(scores_tloc.cbegin(), scores_tloc.cend(), 0.0);
double weights_sum =
std::accumulate(weights_tloc.cbegin(), weights_tloc.cend(), 0.0);
CheckLabelError(label_error, n_class);
PackedReduceResult res { residue_sum, weights_sum };
@ -131,7 +137,8 @@ class MultiClassMetricsReduction {
PackedReduceResult result;
if (device < 0) {
result = CpuReduceMetrics(weights, labels, preds, n_class);
result =
CpuReduceMetrics(weights, labels, preds, n_class, tparam.Threads());
}
#if defined(XGBOOST_USE_CUDA)
else { // NOLINT

View File

@ -2,9 +2,44 @@
* Copyright 2018-2019 XGBoost contributors
*/
#include <xgboost/metric.h>
#include <xgboost/json.h>
#include <map>
#include <memory>
#include "../helpers.h"
namespace xgboost {
namespace {
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
auto lparam = CreateEmptyGenericParam(device);
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
HostDeviceVector<float> predts;
MetaInfo info;
auto &h_labels = info.labels_.HostVector();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
size_t n_samples = 2048;
h_labels.resize(n_samples);
h_predts.resize(n_samples);
for (size_t i = 0; i < n_samples; ++i) {
h_predts[i] = dist(&lcg);
h_labels[i] = dist(&lcg);
}
auto result = metric->Eval(predts, info, false);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info, false), result);
}
}
} // anonymous namespace
} // namespace xgboost
TEST(Metric, DeclareUnifiedTest(RMSE)) {
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
@ -26,6 +61,8 @@ TEST(Metric, DeclareUnifiedTest(RMSE)) {
{ 1, 2, 9, 8}),
0.6708f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmse"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(RMSLE)) {
@ -49,6 +86,8 @@ TEST(Metric, DeclareUnifiedTest(RMSLE)) {
{ 0, 1, 2, 9, 8}),
0.2415f, 1e-4);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmsle"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(MAE)) {
@ -72,6 +111,8 @@ TEST(Metric, DeclareUnifiedTest(MAE)) {
{ 1, 2, 9, 8}),
0.54f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mae"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(MAPE)) {
@ -95,6 +136,8 @@ TEST(Metric, DeclareUnifiedTest(MAPE)) {
{ 1, 2, 9, 8}),
1.3250f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mape"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(MPHE)) {
@ -118,6 +161,8 @@ TEST(Metric, DeclareUnifiedTest(MPHE)) {
{ 1, 2, 9, 8}),
0.1922f, 1e-4);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(LogLoss)) {
@ -145,6 +190,8 @@ TEST(Metric, DeclareUnifiedTest(LogLoss)) {
{ 1, 2, 9, 8}),
1.3138f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"logloss"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(Error)) {
@ -197,6 +244,8 @@ TEST(Metric, DeclareUnifiedTest(Error)) {
{ 1, 2, 9, 8}),
0.45f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"error@0.5"}, GPUIDX);
}
TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
@ -224,4 +273,6 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
{ 1, 2, 9, 8}),
1.5783f, 0.001f);
delete metric;
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
}

View File

@ -4,6 +4,43 @@
#include "../helpers.h"
namespace xgboost {
inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) {
auto lparam = CreateEmptyGenericParam(device);
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
HostDeviceVector<float> predts;
MetaInfo info;
auto &h_labels = info.labels_.HostVector();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
size_t n_samples = 2048, n_classes = 4;
h_labels.resize(n_samples);
h_predts.resize(n_samples * n_classes);
{
SimpleRealUniformDistribution<float> dist{0.0f, static_cast<float>(n_classes)};
for (size_t i = 0; i < n_samples; ++i) {
h_labels[i] = dist(&lcg);
}
}
{
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
for (size_t i = 0; i < n_samples * n_classes; ++i) {
h_predts[i] = dist(&lcg);
}
}
auto result = metric->Eval(predts, info, false);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info, false), result);
}
}
} // namespace xgboost
inline void TestMultiClassError(int device) {
auto lparam = xgboost::CreateEmptyGenericParam(device);
lparam.gpu_id = device;
@ -17,12 +54,12 @@ inline void TestMultiClassError(int device) {
{0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f},
{0, 1, 2}),
0.666f, 0.001f);
delete metric;
}
TEST(Metric, DeclareUnifiedTest(MultiClassError)) {
TestMultiClassError(GPUIDX);
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"merror"}, GPUIDX);
}
inline void TestMultiClassLogLoss(int device) {
@ -44,6 +81,7 @@ inline void TestMultiClassLogLoss(int device) {
TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
TestMultiClassLogLoss(GPUIDX);
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX);
}
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)