Deterministic result for element-wise/mclass metrics. (#7303)
Remove openmp reduction.
This commit is contained in:
parent
406c70ba0e
commit
4ddf8d001c
@ -14,6 +14,7 @@
|
||||
#include "metric_common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
@ -34,29 +35,29 @@ class ElementWiseMetricsReduction {
|
||||
public:
|
||||
explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {}
|
||||
|
||||
PackedReduceResult CpuReduceMetrics(
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) const {
|
||||
PackedReduceResult
|
||||
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
|
||||
const HostDeviceVector<bst_float> &labels,
|
||||
const HostDeviceVector<bst_float> &preds,
|
||||
int32_t n_threads) const {
|
||||
size_t ndata = labels.Size();
|
||||
|
||||
const auto& h_labels = labels.HostVector();
|
||||
const auto& h_weights = weights.HostVector();
|
||||
const auto& h_preds = preds.HostVector();
|
||||
|
||||
bst_float residue_sum = 0;
|
||||
bst_float weights_sum = 0;
|
||||
std::vector<double> score_tloc(n_threads, 0.0);
|
||||
std::vector<double> weight_tloc(n_threads, 0.0);
|
||||
|
||||
common::ParallelFor(ndata, n_threads, [&](size_t i) {
|
||||
float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
|
||||
auto t_idx = omp_get_thread_num();
|
||||
score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
|
||||
weight_tloc[t_idx] += wt;
|
||||
});
|
||||
double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
|
||||
double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
exc.Run([&]() {
|
||||
const bst_float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
|
||||
residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
|
||||
weights_sum += wt;
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
PackedReduceResult res { residue_sum, weights_sum };
|
||||
return res;
|
||||
}
|
||||
@ -100,19 +101,19 @@ class ElementWiseMetricsReduction {
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
PackedReduceResult Reduce(
|
||||
const GenericParameter &tparam,
|
||||
int device,
|
||||
const GenericParameter &ctx,
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) {
|
||||
PackedReduceResult result;
|
||||
|
||||
if (device < 0) {
|
||||
result = CpuReduceMetrics(weights, labels, preds);
|
||||
if (ctx.gpu_id < 0) {
|
||||
auto n_threads = ctx.Threads();
|
||||
result = CpuReduceMetrics(weights, labels, preds, n_threads);
|
||||
}
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
else { // NOLINT
|
||||
device_ = device;
|
||||
device_ = ctx.gpu_id;
|
||||
preds.SetDevice(device_);
|
||||
labels.SetDevice(device_);
|
||||
weights.SetDevice(device_);
|
||||
@ -365,10 +366,7 @@ struct EvalEWiseBase : public Metric {
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "label and prediction size not match, "
|
||||
<< "hint: use merror or mlogloss for multi-class classification";
|
||||
int device = tparam_->gpu_id;
|
||||
|
||||
auto result =
|
||||
reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds);
|
||||
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_, preds);
|
||||
|
||||
double dat[2] { result.Residue(), result.Weights() };
|
||||
|
||||
|
||||
@ -6,11 +6,14 @@
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
@ -37,38 +40,41 @@ class MultiClassMetricsReduction {
|
||||
public:
|
||||
MultiClassMetricsReduction() = default;
|
||||
|
||||
PackedReduceResult CpuReduceMetrics(
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds,
|
||||
const size_t n_class) const {
|
||||
PackedReduceResult
|
||||
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
|
||||
const HostDeviceVector<bst_float> &labels,
|
||||
const HostDeviceVector<bst_float> &preds,
|
||||
const size_t n_class, int32_t n_threads) const {
|
||||
size_t ndata = labels.Size();
|
||||
|
||||
const auto& h_labels = labels.HostVector();
|
||||
const auto& h_weights = weights.HostVector();
|
||||
const auto& h_preds = preds.HostVector();
|
||||
|
||||
bst_float residue_sum = 0;
|
||||
bst_float weights_sum = 0;
|
||||
int label_error = 0;
|
||||
std::atomic<int> label_error {0};
|
||||
bool const is_null_weight = weights.Size() == 0;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
||||
for (omp_ulong idx = 0; idx < ndata; ++idx) {
|
||||
exc.Run([&]() {
|
||||
std::vector<double> scores_tloc(n_threads, 0);
|
||||
std::vector<double> weights_tloc(n_threads, 0);
|
||||
common::ParallelFor(ndata, n_threads, [&](size_t idx) {
|
||||
bst_float weight = is_null_weight ? 1.0f : h_weights[idx];
|
||||
auto label = static_cast<int>(h_labels[idx]);
|
||||
if (label >= 0 && label < static_cast<int>(n_class)) {
|
||||
residue_sum += EvalRowPolicy::EvalRow(
|
||||
label, h_preds.data() + idx * n_class, n_class) * weight;
|
||||
weights_sum += weight;
|
||||
auto t_idx = omp_get_thread_num();
|
||||
scores_tloc[t_idx] +=
|
||||
EvalRowPolicy::EvalRow(label, h_preds.data() + idx * n_class,
|
||||
n_class) *
|
||||
weight;
|
||||
weights_tloc[t_idx] += weight;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
});
|
||||
|
||||
double residue_sum =
|
||||
std::accumulate(scores_tloc.cbegin(), scores_tloc.cend(), 0.0);
|
||||
double weights_sum =
|
||||
std::accumulate(weights_tloc.cbegin(), weights_tloc.cend(), 0.0);
|
||||
|
||||
CheckLabelError(label_error, n_class);
|
||||
PackedReduceResult res { residue_sum, weights_sum };
|
||||
@ -131,7 +137,8 @@ class MultiClassMetricsReduction {
|
||||
PackedReduceResult result;
|
||||
|
||||
if (device < 0) {
|
||||
result = CpuReduceMetrics(weights, labels, preds, n_class);
|
||||
result =
|
||||
CpuReduceMetrics(weights, labels, preds, n_class, tparam.Threads());
|
||||
}
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
else { // NOLINT
|
||||
|
||||
@ -2,9 +2,44 @@
|
||||
* Copyright 2018-2019 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
|
||||
auto lparam = CreateEmptyGenericParam(device);
|
||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto &h_labels = info.labels_.HostVector();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
||||
|
||||
size_t n_samples = 2048;
|
||||
h_labels.resize(n_samples);
|
||||
h_predts.resize(n_samples);
|
||||
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_predts[i] = dist(&lcg);
|
||||
h_labels[i] = dist(&lcg);
|
||||
}
|
||||
|
||||
auto result = metric->Eval(predts, info, false);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Eval(predts, info, false), result);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
} // namespace xgboost
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(RMSE)) {
|
||||
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
|
||||
@ -26,6 +61,8 @@ TEST(Metric, DeclareUnifiedTest(RMSE)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
0.6708f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmse"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(RMSLE)) {
|
||||
@ -49,6 +86,8 @@ TEST(Metric, DeclareUnifiedTest(RMSLE)) {
|
||||
{ 0, 1, 2, 9, 8}),
|
||||
0.2415f, 1e-4);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmsle"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MAE)) {
|
||||
@ -72,6 +111,8 @@ TEST(Metric, DeclareUnifiedTest(MAE)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
0.54f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mae"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MAPE)) {
|
||||
@ -95,6 +136,8 @@ TEST(Metric, DeclareUnifiedTest(MAPE)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
1.3250f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mape"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MPHE)) {
|
||||
@ -118,6 +161,8 @@ TEST(Metric, DeclareUnifiedTest(MPHE)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
0.1922f, 1e-4);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(LogLoss)) {
|
||||
@ -145,6 +190,8 @@ TEST(Metric, DeclareUnifiedTest(LogLoss)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
1.3138f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"logloss"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(Error)) {
|
||||
@ -197,6 +244,8 @@ TEST(Metric, DeclareUnifiedTest(Error)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
0.45f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"error@0.5"}, GPUIDX);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
|
||||
@ -224,4 +273,6 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
|
||||
{ 1, 2, 9, 8}),
|
||||
1.5783f, 0.001f);
|
||||
delete metric;
|
||||
|
||||
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
|
||||
}
|
||||
|
||||
@ -4,6 +4,43 @@
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) {
|
||||
auto lparam = CreateEmptyGenericParam(device);
|
||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto &h_labels = info.labels_.HostVector();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
|
||||
size_t n_samples = 2048, n_classes = 4;
|
||||
h_labels.resize(n_samples);
|
||||
h_predts.resize(n_samples * n_classes);
|
||||
|
||||
{
|
||||
SimpleRealUniformDistribution<float> dist{0.0f, static_cast<float>(n_classes)};
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_labels[i] = dist(&lcg);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
||||
for (size_t i = 0; i < n_samples * n_classes; ++i) {
|
||||
h_predts[i] = dist(&lcg);
|
||||
}
|
||||
}
|
||||
|
||||
auto result = metric->Eval(predts, info, false);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Eval(predts, info, false), result);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
inline void TestMultiClassError(int device) {
|
||||
auto lparam = xgboost::CreateEmptyGenericParam(device);
|
||||
lparam.gpu_id = device;
|
||||
@ -17,12 +54,12 @@ inline void TestMultiClassError(int device) {
|
||||
{0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f},
|
||||
{0, 1, 2}),
|
||||
0.666f, 0.001f);
|
||||
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MultiClassError)) {
|
||||
TestMultiClassError(GPUIDX);
|
||||
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"merror"}, GPUIDX);
|
||||
}
|
||||
|
||||
inline void TestMultiClassLogLoss(int device) {
|
||||
@ -44,6 +81,7 @@ inline void TestMultiClassLogLoss(int device) {
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
|
||||
TestMultiClassLogLoss(GPUIDX);
|
||||
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user