GPU multiclass metrics (#4368)

* Port multi classes metrics to CUDA.
This commit is contained in:
Jiaming Yuan 2019-04-15 17:47:47 +08:00 committed by GitHub
parent be7bc07ca3
commit 84d992babc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 368 additions and 189 deletions

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2015-2018 by Contributors
* Copyright 2015-2019 by Contributors
* \file elementwise_metric.cc
* \brief evaluation metrics for elementwise binary or regression.
* \author Kailong Chen, Tianqi Chen
@ -9,15 +9,15 @@
#include <dmlc/registry.h>
#include <cmath>
#include "metric_param.h"
#include "metric_common.h"
#include "../common/math.h"
#include "../common/common.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>
#include <thrust/execution_policy.h>
#include <thrust/execution_policy.h> // thrust::cuda::par
#include <thrust/functional.h> // thrust::plus<>
#include <thrust/transform_reduce.h>
#include <thrust/iterator/counting_iterator.h>
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
@ -28,29 +28,9 @@ namespace metric {
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
template <typename EvalRow>
class MetricsReduction {
class ElementWiseMetricsReduction {
public:
class PackedReduceResult {
double residue_sum_;
double weights_sum_;
friend MetricsReduction;
public:
XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
XGBOOST_DEVICE PackedReduceResult(double residue, double weight) :
residue_sum_{residue}, weights_sum_{weight} {}
XGBOOST_DEVICE
PackedReduceResult operator+(PackedReduceResult const& other) const {
return PackedReduceResult { residue_sum_ + other.residue_sum_,
weights_sum_ + other.weights_sum_ };
}
double Residue() const { return residue_sum_; }
double Weights() const { return weights_sum_; }
};
public:
explicit MetricsReduction(EvalRow policy) :
explicit ElementWiseMetricsReduction(EvalRow policy) :
policy_(std::move(policy)) {}
PackedReduceResult CpuReduceMetrics(
@ -144,9 +124,8 @@ class MetricsReduction {
DeviceReduceMetrics(id, index, weights, labels, preds);
}
for (size_t i = 0; i < devices.Size(); ++i) {
result.residue_sum_ += res_per_device[i].residue_sum_;
result.weights_sum_ += res_per_device[i].weights_sum_;
for (auto const& res : res_per_device) {
result += res;
}
}
#endif // defined(XGBOOST_USE_CUDA)
@ -370,7 +349,7 @@ struct EvalEWiseBase : public Metric {
MetricParam param_;
MetricsReduction<Policy> reducer_;
ElementWiseMetricsReduction<Policy> reducer_;
};
XGBOOST_REGISTER_METRIC(RMSE, "rmse")

View File

@ -6,7 +6,7 @@
#include <xgboost/metric.h>
#include <dmlc/registry.h>
#include "metric_param.h"
#include "metric_common.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);

View File

@ -0,0 +1,54 @@
/*!
* Copyright 2018-2019 by Contributors
* \file metric_common.h
*/
#ifndef XGBOOST_METRIC_METRIC_COMMON_H_
#define XGBOOST_METRIC_METRIC_COMMON_H_
#include <dmlc/parameter.h>
#include "../common/common.h"
namespace xgboost {
namespace metric {
// Created exclusively for GPU.
/*!
 * \brief Device-selection parameters consumed by the evaluation metrics.
 *
 * The fields are filled in from user-supplied key/value arguments through
 * the dmlc parameter machinery.
 */
struct MetricParam : public dmlc::Parameter<MetricParam> {
  /*! \brief Number of GPUs to use; defaults to 1, lower bound GPUSet::kAll. */
  int n_gpus;
  /*! \brief GPU id used for evaluation; defaults to 0 and must be >= 0. */
  int gpu_id;
  DMLC_DECLARE_PARAMETER(MetricParam) {
    DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
        .describe("Number of GPUs to use for multi-gpu algorithms.");
    // The help text previously mentioned "objective function evaluation",
    // a leftover from the objective parameters; this struct configures metrics.
    DMLC_DECLARE_FIELD(gpu_id)
        .set_lower_bound(0)
        .set_default(0)
        .describe("gpu to use for metric evaluation");
  };
};
/*!
 * \brief Pair of running sums produced by a metric reduction: the accumulated
 *        residue and the accumulated weight, both kept in double precision.
 */
class PackedReduceResult {
  double residue_sum_;
  double weights_sum_;

 public:
  /*! \brief Create an empty accumulation (both sums are zero). */
  XGBOOST_DEVICE PackedReduceResult() : residue_sum_(0), weights_sum_(0) {}

  /*! \brief Create an accumulation holding the given residue and weight. */
  XGBOOST_DEVICE PackedReduceResult(double residue, double weight)
      : residue_sum_(residue), weights_sum_(weight) {}

  /*! \brief Component-wise sum of two results; neither operand is modified. */
  XGBOOST_DEVICE
  PackedReduceResult operator+(PackedReduceResult const &rhs) const {
    double const residue = residue_sum_ + rhs.residue_sum_;
    double const weights = weights_sum_ + rhs.weights_sum_;
    return PackedReduceResult(residue, weights);
  }

  /*! \brief Add another result into this one in place. */
  PackedReduceResult &operator+=(PackedReduceResult const &rhs) {
    residue_sum_ += rhs.residue_sum_;
    weights_sum_ += rhs.weights_sum_;
    return *this;
  }

  /*! \brief Accumulated residue. */
  double Residue() const { return residue_sum_; }
  /*! \brief Accumulated weight. */
  double Weights() const { return weights_sum_; }
};
} // namespace metric
} // namespace xgboost
#endif // XGBOOST_METRIC_METRIC_COMMON_H_

View File

@ -1,31 +0,0 @@
/*!
* Copyright 2018 by Contributors
* \file metric_param.cc
*/
#ifndef XGBOOST_METRIC_METRIC_PARAM_H_
#define XGBOOST_METRIC_METRIC_PARAM_H_
#include <dmlc/parameter.h>
#include "../common/common.h"
namespace xgboost {
namespace metric {
// Created exclusively for GPU.
// Device-selection parameters for the evaluation metrics, declared through
// the dmlc parameter macros.
struct MetricParam : public dmlc::Parameter<MetricParam> {
// Number of GPUs to use; defaults to 1 with GPUSet::kAll as the lower bound.
int n_gpus;
// GPU id; defaults to 0 and must be non-negative.
int gpu_id;
DMLC_DECLARE_PARAMETER(MetricParam) {
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
.describe("Number of GPUs to use for multi-gpu algorithms.");
// NOTE(review): the help text below mentions "objective function
// evaluation" although this struct lives in the metric namespace.
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
};
};
} // namespace metric
} // namespace xgboost
#endif // XGBOOST_METRIC_METRIC_PARAM_H_

View File

@ -1,126 +1,8 @@
/*!
* Copyright 2015 by Contributors
* \file multiclass_metric.cc
* \brief evaluation metrics for multiclass classification.
* \author Kailong Chen, Tianqi Chen
* Copyright 2019 XGBoost contributors
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <cmath>
#include "../common/math.h"
// Dummy file to keep the CUDA conditional compile trick.
namespace xgboost {
namespace metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
/*!
* \brief base class of multi-class evaluation
* \tparam Derived the name of subclass
*/
template<typename Derived>
struct EvalMClassBase : public Metric {
// Host-only evaluation: accumulates Derived::EvalRow(...) * weight and the
// weights over all rows, optionally all-reduces both sums across workers,
// and returns Derived::GetFinal(sum, weight_sum).
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK(preds.Size() % info.labels_.Size() == 0)
<< "label and prediction size not match";
// Predictions hold nclass consecutive values for each label.
const size_t nclass = preds.Size() / info.labels_.Size();
CHECK_GE(nclass, 1U)
<< "mlogloss and merror are only used for multi-class classification,"
<< " use logloss for binary classification";
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
double sum = 0.0, wsum = 0.0;
// Stays 0 unless some row carries a label outside [0, nclass).
int label_error = 0;
const auto& labels = info.labels_.HostVector();
const auto& weights = info.weights_.HostVector();
const std::vector<bst_float>& h_preds = preds.HostVector();
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
// An empty weight vector means every row has weight 1.
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
auto label = static_cast<int>(labels[i]);
if (label >= 0 && label < static_cast<int>(nclass)) {
sum += Derived::EvalRow(label,
h_preds.data() + i * nclass,
nclass) * wt;
wsum += wt;
} else {
// NOTE(review): written from inside the parallel loop without any
// synchronisation; rows with an invalid label contribute to neither sum.
label_error = label;
}
}
// Fails when any row recorded an out-of-range label above.
CHECK(label_error >= 0 && label_error < static_cast<int>(nclass))
<< "MultiClassEvaluation: label must be in [0, num_class),"
<< " num_class=" << nclass << " but found " << label_error << " in label";
double dat[2]; dat[0] = sum, dat[1] = wsum;
if (distributed) {
// Sum both partial totals across all rabit workers.
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
return Derived::GetFinal(dat[0], dat[1]);
}
/*!
* \brief to be implemented by subclass,
* get evaluation result from one row
* \param label label of current instance
* \param pred prediction value of current instance
* \param nclass number of class in the prediction
*/
inline static bst_float EvalRow(int label,
const bst_float *pred,
size_t nclass);
/*!
* \brief to be overridden by subclass, final transformation
* \param esum the sum statistics returned by EvalRow
* \param wsum sum of weight
*/
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
return esum / wsum;
}
private:
// used to store error message
// NOTE(review): never read or written anywhere inside this struct.
const char *error_msg_;
};
/*!
 * \brief "merror" metric: a row scores 1 when the position reported by
 *        common::FindMaxIndex over its predictions differs from the label's
 *        position, and 0 otherwise.
 */
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
  const char* Name() const override { return "merror"; }

  inline static bst_float EvalRow(int label,
                                  const bst_float *pred,
                                  size_t nclass) {
    const auto best = common::FindMaxIndex(pred, pred + nclass);
    const bst_float *expected = pred + label;
    // The boolean mismatch flag converts to 1.0f / 0.0f on return.
    return best != expected;
  }
};
/*!
 * \brief "mlogloss" metric: negative log of the prediction stored at the
 *        label's position, with that prediction floored at 1e-16.
 */
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
  const char* Name() const override { return "mlogloss"; }

  inline static bst_float EvalRow(int label,
                                  const bst_float *pred,
                                  size_t nclass) {
    const bst_float eps = 1e-16f;
    const bst_float prob = pred[static_cast<size_t>(label)];
    // Anything not strictly above eps (including NaN) is replaced by eps.
    return -std::log(prob > eps ? prob : eps);
  }
};
// Register the "merror" metric; the factory ignores its parameter string and
// returns a newly allocated EvalMatchError.
XGBOOST_REGISTER_METRIC(MatchError, "merror")
.describe("Multiclass classification error.")
.set_body([](const char* param) { return new EvalMatchError(); });
// Register the "mlogloss" metric; the factory ignores its parameter string
// and returns a newly allocated EvalMultiLogLoss.
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
.describe("Multiclass negative loglikelihood.")
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
} // namespace metric
} // namespace xgboost
#if !defined(XGBOOST_USE_CUDA)
#include "multiclass_metric.cu"
#endif // !defined(XGBOOST_USE_CUDA)

View File

@ -0,0 +1,261 @@
/*!
* Copyright 2015-2019 by Contributors
* \file multiclass_metric.cu
* \brief evaluation metrics for multiclass classification.
* \author Kailong Chen, Tianqi Chen
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <cmath>
#include "metric_common.h"
#include "../common/math.h"
#include "../common/common.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
#include <thrust/functional.h> // thrust::plus<>
#include <thrust/transform_reduce.h>
#include <thrust/iterator/counting_iterator.h>
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
namespace xgboost {
namespace metric {
// Tag this file; used to force static linking later.
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
/*!
 * \brief Reduces per-row multiclass metric values into a
 *        (residue sum, weight sum) pair, either on the host or on the
 *        devices of a GPUSet.
 * \tparam EvalRowPolicy type exposing a static
 *         EvalRow(label, pred, n_class) that scores a single row
 */
template <typename EvalRowPolicy>
class MultiClassMetricsReduction {
  /*!
   * \brief Fail with a descriptive message unless label_error lies in
   *        [0, n_class).
   *
   * The reducers leave the tracked value at 0 when every label is valid and
   * overwrite it with an offending label otherwise.
   */
  void CheckLabelError(int32_t label_error, size_t n_class) const {
    CHECK(label_error >= 0 && label_error < static_cast<int32_t>(n_class))
        << "MultiClassEvaluation: label must be in [0, num_class),"
        << " num_class=" << n_class << " but found " << label_error << " in label";
  }

 public:
  MultiClassMetricsReduction() = default;

  /*!
   * \brief Host reduction over all rows.
   * \param weights per-row weights; an empty vector means weight 1 for each row
   * \param labels  per-row class labels stored as floats
   * \param preds   predictions, n_class consecutive values per row
   * \param n_class number of classes
   * \return sum of EvalRow(...) * weight and sum of weight over valid rows
   */
  PackedReduceResult CpuReduceMetrics(
      const HostDeviceVector<bst_float>& weights,
      const HostDeviceVector<bst_float>& labels,
      const HostDeviceVector<bst_float>& preds,
      const size_t n_class) const {
    size_t ndata = labels.Size();

    const auto& h_labels = labels.HostVector();
    const auto& h_weights = weights.HostVector();
    const auto& h_preds = preds.HostVector();

    // Accumulate in double: PackedReduceResult stores doubles, and a float
    // accumulator stops growing once the running sum reaches 2^24 when
    // adding unit weights.
    double residue_sum = 0.0;
    double weights_sum = 0.0;
    int label_error = 0;
    bool const is_null_weight = weights.Size() == 0;

#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
    for (omp_ulong idx = 0; idx < ndata; ++idx) {
      bst_float weight = is_null_weight ? 1.0f : h_weights[idx];
      auto label = static_cast<int>(h_labels[idx]);
      if (label >= 0 && label < static_cast<int>(n_class)) {
        residue_sum += EvalRowPolicy::EvalRow(
            label, h_preds.data() + idx * n_class, n_class) * weight;
        weights_sum += weight;
      } else {
        // Error path only: serialise the write so threads that hit invalid
        // labels concurrently do not race on the shared variable.
#pragma omp critical
        {
          label_error = label;
        }
      }
    }
    CheckLabelError(label_error, n_class);
    PackedReduceResult res { residue_sum, weights_sum };
    return res;
  }

#if defined(XGBOOST_USE_CUDA)

  /*!
   * \brief Reduction over the shard of rows held by one device.
   * \param device_id    id of the device whose shard is reduced
   * \param device_index position of that device's entry in allocators_
   * \param weights per-row weights; an empty vector means weight 1 for each row
   * \param labels  per-row class labels stored as floats
   * \param preds   predictions, n_class consecutive values per row
   * \param n_class number of classes
   * \return partial sums for this device
   */
  PackedReduceResult DeviceReduceMetrics(
      GPUSet::GpuIdType device_id,
      size_t device_index,
      const HostDeviceVector<bst_float>& weights,
      const HostDeviceVector<bst_float>& labels,
      const HostDeviceVector<bst_float>& preds,
      const size_t n_class) {
    size_t n_data = labels.DeviceSize(device_id);

    thrust::counting_iterator<size_t> begin(0);
    thrust::counting_iterator<size_t> end = begin + n_data;

    auto s_labels = labels.DeviceSpan(device_id);
    auto s_preds = preds.DeviceSpan(device_id);
    auto s_weights = weights.DeviceSpan(device_id);

    bool const is_null_weight = weights.Size() == 0;
    // Single int32 slot written by the device lambda and read back on the
    // host after the reduction returns.
    // NOTE(review): label_error_ is one member shared by every device while
    // Reduce() may call this function from several OpenMP threads; confirm
    // that the reset below cannot hide an error recorded for another device.
    auto s_label_error = label_error_.GetSpan<int32_t>(1);
    s_label_error[0] = 0;

    PackedReduceResult result = thrust::transform_reduce(
        thrust::cuda::par(allocators_.at(device_index)),
        begin, end,
        [=] XGBOOST_DEVICE(size_t idx) {
          bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
          bst_float residue = 0;
          auto label = static_cast<int>(s_labels[idx]);
          if (label >= 0 && label < static_cast<int32_t>(n_class)) {
            residue = EvalRowPolicy::EvalRow(
                label, &s_preds[idx * n_class], n_class) * weight;
          } else {
            s_label_error[0] = label;
          }
          return PackedReduceResult{ residue, weight };
        },
        PackedReduceResult(),
        thrust::plus<PackedReduceResult>());
    CheckLabelError(s_label_error[0], n_class);

    return result;
  }

#endif  // XGBOOST_USE_CUDA

  /*!
   * \brief Entry point: reduce on the host when devices is empty, otherwise
   *        (CUDA builds only) reduce each device's shard and add the
   *        per-device results together.
   * \param devices set of devices to use
   * \param n_class number of classes
   * \param weights per-row weights; an empty vector means weight 1 for each row
   * \param labels  per-row class labels stored as floats
   * \param preds   predictions, n_class consecutive values per row
   * \return total residue and weight sums
   */
  PackedReduceResult Reduce(
      GPUSet devices,
      size_t n_class,
      const HostDeviceVector<bst_float>& weights,
      const HostDeviceVector<bst_float>& labels,
      const HostDeviceVector<bst_float>& preds) {
    PackedReduceResult result;

    if (devices.IsEmpty()) {
      result = CpuReduceMetrics(weights, labels, preds, n_class);
    }
#if defined(XGBOOST_USE_CUDA)
    else {  // NOLINT
      // Keep exactly one allocator per device in the set.
      if (allocators_.size() != devices.Size()) {
        allocators_.clear();
        allocators_.resize(devices.Size());
      }
      // Predictions are sharded with granularity n_class so the values of a
      // row are never split between devices.
      preds.Reshard(GPUDistribution::Granular(devices, n_class));
      labels.Reshard(devices);
      weights.Reshard(devices);
      std::vector<PackedReduceResult> res_per_device(devices.Size());

#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
      for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) {
        dh::safe_cuda(cudaSetDevice(id));
        size_t index = devices.Index(id);
        res_per_device.at(index) =
            DeviceReduceMetrics(id, index, weights, labels, preds, n_class);
      }

      for (auto const& res : res_per_device) {
        result += res;
      }
    }
#endif  // defined(XGBOOST_USE_CUDA)
    return result;
  }

 private:
#if defined(XGBOOST_USE_CUDA)
  // Slot through which the device lambda reports an out-of-range label.
  dh::PinnedMemory label_error_;
  // One allocator per device, passed to the thrust execution policy.
  std::vector<dh::CubMemory> allocators_;
#endif  // defined(XGBOOST_USE_CUDA)
};
/*!
* \brief base class of multi-class evaluation
* \tparam Derived the name of subclass
*/
template<typename Derived>
struct EvalMClassBase : public Metric {
  /*!
   * \brief Initialise the device parameters from key/value arguments;
   *        unknown keys are tolerated.
   */
  void Configure(
      const std::vector<std::pair<std::string, std::string> >& args) override {
    param_.InitAllowUnknown(args);
  }

  /*!
   * \brief Evaluate the metric.
   * \param preds predictions, nclass consecutive values per label
   * \param info  meta information supplying labels and optional weights
   * \param distributed whether to sum the partial results across rabit workers
   * \return Derived::GetFinal(residue sum, weight sum)
   */
  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                 const MetaInfo &info,
                 bool distributed) override {
    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
    CHECK(preds.Size() % info.labels_.Size() == 0)
        << "label and prediction size not match";
    const size_t nclass = preds.Size() / info.labels_.Size();
    CHECK_GE(nclass, 1U)
        << "mlogloss and merror are only used for multi-class classification,"
        << " use logloss for binary classification";

    const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());

    // The reducer runs on the host when the resulting device set is empty.
    GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata);
    auto result = reducer_.Reduce(devices, nclass, info.weights_, info.labels_, preds);
    double dat[2] { result.Residue(), result.Weights() };

    if (distributed) {
      rabit::Allreduce<rabit::op::Sum>(dat, 2);
    }
    return Derived::GetFinal(dat[0], dat[1]);
  }
  /*!
   * \brief to be implemented by subclass,
   *   get evaluation result from one row
   * \param label label of current instance
   * \param pred prediction value of current instance
   * \param nclass number of class in the prediction
   */
  XGBOOST_DEVICE static bst_float EvalRow(int label,
                                          const bst_float *pred,
                                          size_t nclass);
  /*!
   * \brief to be overridden by subclass, final transformation
   * \param esum the sum statistics returned by EvalRow
   * \param wsum sum of weight
   */
  inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
    return esum / wsum;
  }

 private:
  // Performs the host or device reduction for the Derived metric.
  MultiClassMetricsReduction<Derived> reducer_;
  // Device selection (gpu_id, n_gpus) filled in by Configure().
  MetricParam param_;
  // The former `error_msg_` pointer was removed: it was private, never
  // initialised and never referenced.
};
/*!
 * \brief "merror" metric: a row scores 1 when the position reported by
 *        common::FindMaxIndex over its predictions differs from the label's
 *        position, and 0 otherwise.
 */
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
  const char* Name() const override { return "merror"; }

  XGBOOST_DEVICE static bst_float EvalRow(int label,
                                          const bst_float *pred,
                                          size_t nclass) {
    const auto best = common::FindMaxIndex(pred, pred + nclass);
    const bst_float *expected = pred + label;
    // The boolean mismatch flag converts to 1.0f / 0.0f on return.
    return best != expected;
  }
};
/*!
 * \brief "mlogloss" metric: negative log of the prediction stored at the
 *        label's position, with that prediction floored at 1e-16.
 */
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
  const char* Name() const override { return "mlogloss"; }

  XGBOOST_DEVICE static bst_float EvalRow(int label,
                                          const bst_float *pred,
                                          size_t nclass) {
    const bst_float eps = 1e-16f;
    const bst_float prob = pred[static_cast<size_t>(label)];
    // Anything not strictly above eps (including NaN) is replaced by eps.
    return -std::log(prob > eps ? prob : eps);
  }
};
// Register the "merror" metric; the factory ignores its parameter string and
// returns a newly allocated EvalMatchError.
XGBOOST_REGISTER_METRIC(MatchError, "merror")
.describe("Multiclass classification error.")
.set_body([](const char* param) { return new EvalMatchError(); });
// Register the "mlogloss" metric; the factory ignores its parameter string
// and returns a newly allocated EvalMultiLogLoss.
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
.describe("Multiclass negative loglikelihood.")
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
} // namespace metric
} // namespace xgboost

View File

@ -1,10 +1,20 @@
// Copyright by Contributors
#include <xgboost/metric.h>
#include <string>
#include "../helpers.h"
TEST(Metric, MultiClassError) {
using Arg = std::pair<std::string, std::string>;
#if defined(__CUDACC__)
#define N_GPU() Arg{"n_gpus", "1"}
#else
#define N_GPU() Arg{"n_gpus", "0"}
#endif
inline void TestMultiClassError(std::vector<Arg> args) {
xgboost::Metric * metric = xgboost::Metric::Create("merror");
metric->Configure(args);
ASSERT_STREQ(metric->Name(), "merror");
EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0}));
EXPECT_NEAR(GetMetricEval(
@ -17,8 +27,13 @@ TEST(Metric, MultiClassError) {
delete metric;
}
TEST(Metric, MultiClassLogLoss) {
// Runs the shared "merror" checks with the n_gpus argument chosen by N_GPU():
// "1" when compiled with __CUDACC__ defined, "0" otherwise.
TEST(Metric, DeclareUnifiedTest(MultiClassError)) {
TestMultiClassError({N_GPU()});
}
inline void TestMultiClassLogLoss(std::vector<Arg> args) {
xgboost::Metric * metric = xgboost::Metric::Create("mlogloss");
metric->Configure(args);
ASSERT_STREQ(metric->Name(), "mlogloss");
EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0}));
EXPECT_NEAR(GetMetricEval(
@ -30,3 +45,17 @@ TEST(Metric, MultiClassLogLoss) {
delete metric;
}
// Runs the shared "mlogloss" checks with the n_gpus argument chosen by
// N_GPU(): "1" when compiled with __CUDACC__ defined, "0" otherwise.
TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
TestMultiClassLogLoss({N_GPU()});
}
// Multi-GPU variants, compiled only for NCCL-enabled CUDA builds. Both
// metrics are exercised with n_gpus=-1, once with the default gpu_id and once
// with gpu_id=1.
// NOTE(review): the gpu_id=1 cases presumably need at least two visible
// devices — confirm against the CI setup.
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
TEST(Metric, MGPU_MultiClassError) {
TestMultiClassError({Arg{"n_gpus", "-1"}});
TestMultiClassError({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
TestMultiClassLogLoss({Arg{"n_gpus", "-1"}});
TestMultiClassLogLoss({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
}
#endif // defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)

View File

@ -0,0 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
*/
// Dummy file to keep the CUDA conditional compile trick.
#include "test_multiclass_metric.cc"