parent
be7bc07ca3
commit
84d992babc
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* Copyright 2015-2019 by Contributors
|
||||
* \file elementwise_metric.cc
|
||||
* \brief evaluation metrics for elementwise binary or regression.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
@ -9,15 +9,15 @@
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_param.h"
|
||||
#include "metric_common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
#include <thrust/functional.h> // thrust::plus<>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
@ -28,29 +28,9 @@ namespace metric {
|
||||
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
||||
|
||||
template <typename EvalRow>
|
||||
class MetricsReduction {
|
||||
class ElementWiseMetricsReduction {
|
||||
public:
|
||||
class PackedReduceResult {
|
||||
double residue_sum_;
|
||||
double weights_sum_;
|
||||
friend MetricsReduction;
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
|
||||
XGBOOST_DEVICE PackedReduceResult(double residue, double weight) :
|
||||
residue_sum_{residue}, weights_sum_{weight} {}
|
||||
|
||||
XGBOOST_DEVICE
|
||||
PackedReduceResult operator+(PackedReduceResult const& other) const {
|
||||
return PackedReduceResult { residue_sum_ + other.residue_sum_,
|
||||
weights_sum_ + other.weights_sum_ };
|
||||
}
|
||||
double Residue() const { return residue_sum_; }
|
||||
double Weights() const { return weights_sum_; }
|
||||
};
|
||||
|
||||
public:
|
||||
explicit MetricsReduction(EvalRow policy) :
|
||||
explicit ElementWiseMetricsReduction(EvalRow policy) :
|
||||
policy_(std::move(policy)) {}
|
||||
|
||||
PackedReduceResult CpuReduceMetrics(
|
||||
@ -144,9 +124,8 @@ class MetricsReduction {
|
||||
DeviceReduceMetrics(id, index, weights, labels, preds);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < devices.Size(); ++i) {
|
||||
result.residue_sum_ += res_per_device[i].residue_sum_;
|
||||
result.weights_sum_ += res_per_device[i].weights_sum_;
|
||||
for (auto const& res : res_per_device) {
|
||||
result += res;
|
||||
}
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
@ -370,7 +349,7 @@ struct EvalEWiseBase : public Metric {
|
||||
|
||||
MetricParam param_;
|
||||
|
||||
MetricsReduction<Policy> reducer_;
|
||||
ElementWiseMetricsReduction<Policy> reducer_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "metric_param.h"
|
||||
#include "metric_common.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
|
||||
54
src/metric/metric_common.h
Normal file
54
src/metric/metric_common.h
Normal file
@ -0,0 +1,54 @@
|
||||
/*!
|
||||
* Copyright 2018-2019 by Contributors
|
||||
* \file metric_param.cc
|
||||
*/
|
||||
#ifndef XGBOOST_METRIC_METRIC_COMMON_H_
|
||||
#define XGBOOST_METRIC_METRIC_COMMON_H_
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
|
||||
// Created exclusively for GPU.
|
||||
struct MetricParam : public dmlc::Parameter<MetricParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(MetricParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
};
|
||||
};
|
||||
|
||||
class PackedReduceResult {
|
||||
double residue_sum_;
|
||||
double weights_sum_;
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
|
||||
XGBOOST_DEVICE PackedReduceResult(double residue, double weight)
|
||||
: residue_sum_{residue}, weights_sum_{weight} {}
|
||||
|
||||
XGBOOST_DEVICE
|
||||
PackedReduceResult operator+(PackedReduceResult const &other) const {
|
||||
return PackedReduceResult{residue_sum_ + other.residue_sum_,
|
||||
weights_sum_ + other.weights_sum_};
|
||||
}
|
||||
PackedReduceResult &operator+=(PackedReduceResult const &other) {
|
||||
this->residue_sum_ += other.residue_sum_;
|
||||
this->weights_sum_ += other.weights_sum_;
|
||||
return *this;
|
||||
}
|
||||
double Residue() const { return residue_sum_; }
|
||||
double Weights() const { return weights_sum_; }
|
||||
};
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_METRIC_METRIC_COMMON_H_
|
||||
@ -1,31 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file metric_param.cc
|
||||
*/
|
||||
#ifndef XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
#define XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
|
||||
// Created exclusively for GPU.
|
||||
struct MetricParam : public dmlc::Parameter<MetricParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(MetricParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
@ -1,126 +1,8 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file multiclass_metric.cc
|
||||
* \brief evaluation metrics for multiclass classification.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
* Copyright 2019 XGBoost contributors
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
|
||||
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public Metric {
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.Size() % info.labels_.Size() == 0)
|
||||
<< "label and prediction size not match";
|
||||
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
|
||||
double sum = 0.0, wsum = 0.0;
|
||||
int label_error = 0;
|
||||
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const auto& weights = info.weights_.HostVector();
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
||||
auto label = static_cast<int>(labels[i]);
|
||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||
sum += Derived::EvalRow(label,
|
||||
h_preds.data() + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
}
|
||||
CHECK(label_error >= 0 && label_error < static_cast<int>(nclass))
|
||||
<< "MultiClassEvaluation: label must be in [0, num_class),"
|
||||
<< " num_class=" << nclass << " but found " << label_error << " in label";
|
||||
|
||||
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
inline static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
|
||||
private:
|
||||
// used to store error message
|
||||
const char *error_msg_;
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
const char* Name() const override {
|
||||
return "merror";
|
||||
}
|
||||
inline static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass) {
|
||||
return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast<int>(label);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
const char* Name() const override {
|
||||
return "mlogloss";
|
||||
}
|
||||
inline static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass) {
|
||||
const bst_float eps = 1e-16f;
|
||||
auto k = static_cast<size_t>(label);
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
return -std::log(eps);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MatchError, "merror")
|
||||
.describe("Multiclass classification error.")
|
||||
.set_body([](const char* param) { return new EvalMatchError(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
|
||||
.describe("Multiclass negative loglikelihood.")
|
||||
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#include "multiclass_metric.cu"
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
261
src/metric/multiclass_metric.cu
Normal file
261
src/metric/multiclass_metric.cu
Normal file
@ -0,0 +1,261 @@
|
||||
/*!
|
||||
* Copyright 2015-2019 by Contributors
|
||||
* \file multiclass_metric.cc
|
||||
* \brief evaluation metrics for multiclass classification.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
#include <thrust/functional.h> // thrust::plus<>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
|
||||
|
||||
template <typename EvalRowPolicy>
|
||||
class MultiClassMetricsReduction {
|
||||
void CheckLabelError(int32_t label_error, size_t n_class) const {
|
||||
CHECK(label_error >= 0 && label_error < static_cast<int32_t>(n_class))
|
||||
<< "MultiClassEvaluation: label must be in [0, num_class),"
|
||||
<< " num_class=" << n_class << " but found " << label_error << " in label";
|
||||
}
|
||||
|
||||
public:
|
||||
MultiClassMetricsReduction() = default;
|
||||
|
||||
PackedReduceResult CpuReduceMetrics(
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds,
|
||||
const size_t n_class) const {
|
||||
size_t ndata = labels.Size();
|
||||
|
||||
const auto& h_labels = labels.HostVector();
|
||||
const auto& h_weights = weights.HostVector();
|
||||
const auto& h_preds = preds.HostVector();
|
||||
|
||||
bst_float residue_sum = 0;
|
||||
bst_float weights_sum = 0;
|
||||
int label_error = 0;
|
||||
bool const is_null_weight = weights.Size() == 0;
|
||||
|
||||
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
||||
for (omp_ulong idx = 0; idx < ndata; ++idx) {
|
||||
bst_float weight = is_null_weight ? 1.0f : h_weights[idx];
|
||||
auto label = static_cast<int>(h_labels[idx]);
|
||||
if (label >= 0 && label < static_cast<int>(n_class)) {
|
||||
residue_sum += EvalRowPolicy::EvalRow(
|
||||
label, h_preds.data() + idx * n_class, n_class) * weight;
|
||||
weights_sum += weight;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
}
|
||||
CheckLabelError(label_error, n_class);
|
||||
PackedReduceResult res { residue_sum, weights_sum };
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
|
||||
PackedReduceResult DeviceReduceMetrics(
|
||||
GPUSet::GpuIdType device_id,
|
||||
size_t device_index,
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds,
|
||||
const size_t n_class) {
|
||||
size_t n_data = labels.DeviceSize(device_id);
|
||||
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + n_data;
|
||||
|
||||
auto s_labels = labels.DeviceSpan(device_id);
|
||||
auto s_preds = preds.DeviceSpan(device_id);
|
||||
auto s_weights = weights.DeviceSpan(device_id);
|
||||
|
||||
bool const is_null_weight = weights.Size() == 0;
|
||||
auto s_label_error = label_error_.GetSpan<int32_t>(1);
|
||||
s_label_error[0] = 0;
|
||||
|
||||
PackedReduceResult result = thrust::transform_reduce(
|
||||
thrust::cuda::par(allocators_.at(device_index)),
|
||||
begin, end,
|
||||
[=] XGBOOST_DEVICE(size_t idx) {
|
||||
bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
|
||||
bst_float residue = 0;
|
||||
auto label = static_cast<int>(s_labels[idx]);
|
||||
if (label >= 0 && label < static_cast<int32_t>(n_class)) {
|
||||
residue = EvalRowPolicy::EvalRow(
|
||||
label, &s_preds[idx * n_class], n_class) * weight;
|
||||
} else {
|
||||
s_label_error[0] = label;
|
||||
}
|
||||
return PackedReduceResult{ residue, weight };
|
||||
},
|
||||
PackedReduceResult(),
|
||||
thrust::plus<PackedReduceResult>());
|
||||
CheckLabelError(s_label_error[0], n_class);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
PackedReduceResult Reduce(
|
||||
GPUSet devices,
|
||||
size_t n_class,
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) {
|
||||
PackedReduceResult result;
|
||||
|
||||
if (devices.IsEmpty()) {
|
||||
result = CpuReduceMetrics(weights, labels, preds, n_class);
|
||||
}
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
else { // NOLINT
|
||||
if (allocators_.size() != devices.Size()) {
|
||||
allocators_.clear();
|
||||
allocators_.resize(devices.Size());
|
||||
}
|
||||
preds.Reshard(GPUDistribution::Granular(devices, n_class));
|
||||
labels.Reshard(devices);
|
||||
weights.Reshard(devices);
|
||||
std::vector<PackedReduceResult> res_per_device(devices.Size());
|
||||
|
||||
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||
for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) {
|
||||
dh::safe_cuda(cudaSetDevice(id));
|
||||
size_t index = devices.Index(id);
|
||||
res_per_device.at(index) =
|
||||
DeviceReduceMetrics(id, index, weights, labels, preds, n_class);
|
||||
}
|
||||
|
||||
for (auto const& res : res_per_device) {
|
||||
result += res;
|
||||
}
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::PinnedMemory label_error_;
|
||||
std::vector<dh::CubMemory> allocators_;
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public Metric {
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.Size() % info.labels_.Size() == 0)
|
||||
<< "label and prediction size not match";
|
||||
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
|
||||
|
||||
GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata);
|
||||
auto result = reducer_.Reduce(devices, nclass, info.weights_, info.labels_, preds);
|
||||
double dat[2] { result.Residue(), result.Weights() };
|
||||
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
XGBOOST_DEVICE static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
|
||||
private:
|
||||
MultiClassMetricsReduction<Derived> reducer_;
|
||||
MetricParam param_;
|
||||
// used to store error message
|
||||
const char *error_msg_;
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
const char* Name() const override {
|
||||
return "merror";
|
||||
}
|
||||
XGBOOST_DEVICE static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass) {
|
||||
return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast<int>(label);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
const char* Name() const override {
|
||||
return "mlogloss";
|
||||
}
|
||||
XGBOOST_DEVICE static bst_float EvalRow(int label,
|
||||
const bst_float *pred,
|
||||
size_t nclass) {
|
||||
const bst_float eps = 1e-16f;
|
||||
auto k = static_cast<size_t>(label);
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
return -std::log(eps);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MatchError, "merror")
|
||||
.describe("Multiclass classification error.")
|
||||
.set_body([](const char* param) { return new EvalMatchError(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
|
||||
.describe("Multiclass negative loglikelihood.")
|
||||
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
@ -1,10 +1,20 @@
|
||||
// Copyright by Contributors
|
||||
#include <xgboost/metric.h>
|
||||
#include <string>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Metric, MultiClassError) {
|
||||
using Arg = std::pair<std::string, std::string>;
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define N_GPU() Arg{"n_gpus", "1"}
|
||||
#else
|
||||
#define N_GPU() Arg{"n_gpus", "0"}
|
||||
#endif
|
||||
|
||||
inline void TestMultiClassError(std::vector<Arg> args) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("merror");
|
||||
metric->Configure(args);
|
||||
ASSERT_STREQ(metric->Name(), "merror");
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0}));
|
||||
EXPECT_NEAR(GetMetricEval(
|
||||
@ -17,8 +27,13 @@ TEST(Metric, MultiClassError) {
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, MultiClassLogLoss) {
|
||||
TEST(Metric, DeclareUnifiedTest(MultiClassError)) {
|
||||
TestMultiClassError({N_GPU()});
|
||||
}
|
||||
|
||||
inline void TestMultiClassLogLoss(std::vector<Arg> args) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("mlogloss");
|
||||
metric->Configure(args);
|
||||
ASSERT_STREQ(metric->Name(), "mlogloss");
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0}));
|
||||
EXPECT_NEAR(GetMetricEval(
|
||||
@ -30,3 +45,17 @@ TEST(Metric, MultiClassLogLoss) {
|
||||
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
|
||||
TestMultiClassLogLoss({N_GPU()});
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
|
||||
TEST(Metric, MGPU_MultiClassError) {
|
||||
TestMultiClassError({Arg{"n_gpus", "-1"}});
|
||||
TestMultiClassError({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
|
||||
|
||||
TestMultiClassLogLoss({Arg{"n_gpus", "-1"}});
|
||||
TestMultiClassLogLoss({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
|
||||
5
tests/cpp/metric/test_multiclass_metric.cu
Normal file
5
tests/cpp/metric/test_multiclass_metric.cu
Normal file
@ -0,0 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 XGBoost contributors
|
||||
*/
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
#include "test_multiclass_metric.cc"
|
||||
Loading…
x
Reference in New Issue
Block a user