Fix thread safety of softmax prediction. (#7104)

Jiaming Yuan 2021-07-16 02:06:55 +08:00 committed by GitHub
parent 2801d69fb7
commit abec3dbf6d
6 changed files with 20 additions and 23 deletions
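
Summary of the change: PredTransform is made const across every objective, and SoftmaxMultiClassObj stops writing argmax results into the shared member cache max_preds_, using a function-local buffer instead. The hazard being fixed is the usual one for mutable scratch members: below is a minimal standalone sketch (toy types and names, not XGBoost code) of how two concurrent prediction calls race on such a cache.

// Sketch only: a shared mutable scratch buffer inside an objective-like
// object makes concurrent Transform calls race on resize/write.
#include <cstddef>
#include <thread>
#include <vector>

struct OldSoftmaxLike {
  std::vector<float> max_preds_;  // shared scratch buffer (the bug)
  void Transform(const std::vector<float>& preds, int nclass) {
    max_preds_.resize(preds.size() / nclass);  // thread A reallocates...
    for (std::size_t i = 0; i < max_preds_.size(); ++i) {
      max_preds_[i] = 0.0f;  // ...while thread B may still write here: UB
    }
  }
};

int main() {
  OldSoftmaxLike obj;
  std::vector<float> a(3000), b(60);
  std::thread t1([&] { obj.Transform(a, 3); });
  std::thread t2([&] { obj.Transform(b, 3); });  // data race on max_preds_
  t1.join();
  t2.join();
}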

View File

@@ -53,7 +53,7 @@ class ObjFunction : public Configurable {
    * \brief transform prediction values, this is only called when Prediction is called
    * \param io_preds prediction values, saves to this vector as well
    */
-  virtual void PredTransform(HostDeviceVector<bst_float>*) {}
+  virtual void PredTransform(HostDeviceVector<bst_float>*) const {}
   /*!
    * \brief transform prediction values, this is only called when Eval is called,
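
The const on this base-class hook is what makes the fix stick: any objective that still mutates member state inside PredTransform now fails to compile instead of racing at runtime. A toy illustration (illustrative names, not the real hierarchy):

// Once the virtual is const, an override cannot silently write to a
// member cache from the prediction path.
#include <vector>

struct ObjLike {
  virtual void PredTransform(std::vector<float>*) const {}
  virtual ~ObjLike() = default;
};

struct Stateful : ObjLike {
  std::vector<float> cache_;
  void PredTransform(std::vector<float>* io_preds) const override {
    // cache_.resize(io_preds->size());  // error: 'this' is const here
    (void)io_preds;
  }
};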

View File

@@ -58,7 +58,7 @@ class MyLogistic : public ObjFunction {
   const char* DefaultEvalMetric() const override {
     return "logloss";
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     // transform margin value to probability.
     std::vector<bst_float> &preds = io_preds->HostVector();
     for (auto& pred : preds) {

View File

@@ -102,7 +102,7 @@ class AFTObj : public ObjFunction {
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     // Trees give us a prediction in log scale, so exponentiate
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {

View File

@@ -68,7 +68,7 @@ class HingeObj : public ObjFunction {
         out_gpair, &preds, &info.labels_, &info.weights_);
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;

View File

@@ -121,7 +121,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
   void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
@@ -131,10 +131,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
     return "mlogloss";
   }
-  inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
+  inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
     const int nclass = param_.num_class;
     const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
-    max_preds_.Resize(ndata);
 
     auto device = io_preds->DeviceIdx();
     if (prob) {
@@ -148,23 +147,22 @@ class SoftmaxMultiClassObj : public ObjFunction {
           .Eval(io_preds);
     } else {
       io_preds->SetDevice(device);
-      max_preds_.SetDevice(device);
+      HostDeviceVector<bst_float> max_preds;
+      max_preds.SetDevice(device);
+      max_preds.Resize(ndata);
       common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx,
-                             common::Span<const bst_float> _preds,
+          [=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
                              common::Span<bst_float> _max_preds) {
             common::Span<const bst_float> point =
                 _preds.subspan(_idx * nclass, nclass);
             _max_preds[_idx] =
-                common::FindMaxIndex(point.cbegin(),
-                                     point.cend()) - point.cbegin();
+                common::FindMaxIndex(point.cbegin(), point.cend()) -
+                point.cbegin();
           },
           common::Range{0, ndata}, device, false)
-          .Eval(io_preds, &max_preds_);
-    }
-    if (!prob) {
-      io_preds->Resize(max_preds_.Size());
-      io_preds->Copy(max_preds_);
+          .Eval(io_preds, &max_preds);
+      io_preds->Resize(max_preds.Size());
+      io_preds->Copy(max_preds);
     }
   }
@@ -188,7 +186,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
   // parameter
   SoftmaxMultiClassParam param_;
   // Cache for max_preds
-  HostDeviceVector<bst_float> max_preds_;
   HostDeviceVector<int> label_correct_;
 };
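
The fix pattern itself, distilled: the scratch buffer becomes a local variable, so each call owns its storage and the method can be const. A simplified CPU-only sketch, using std::vector where the real code uses HostDeviceVector (ArgMaxTransform is an illustrative name):

#include <algorithm>
#include <cstddef>
#include <vector>

// Argmax transform with a per-call buffer: safe to run from many threads.
void ArgMaxTransform(std::vector<float>* io_preds, int nclass) {
  const std::size_t ndata = io_preds->size() / nclass;
  std::vector<float> max_preds(ndata);  // local, not a shared member
  for (std::size_t i = 0; i < ndata; ++i) {
    auto begin = io_preds->begin() + i * nclass;
    max_preds[i] =
        static_cast<float>(std::max_element(begin, begin + nclass) - begin);
  }
  io_preds->swap(max_preds);  // shrink output to one class index per row
}

The trade-off is one allocation per prediction call instead of a reused cache; that cost buys lock-free safety for concurrent Predict calls.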

View File

@@ -109,7 +109,7 @@ class RegLossObj : public ObjFunction {
     return Loss::DefaultEvalMetric();
   }
-  void PredTransform(HostDeviceVector<float> *io_preds) override {
+  void PredTransform(HostDeviceVector<float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
           _preds[_idx] = Loss::PredTransform(_preds[_idx]);
@@ -233,7 +233,7 @@ class PoissonRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
@@ -343,7 +343,7 @@ class CoxRegression : public ObjFunction {
       last_exp_p = exp_p;
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     std::vector<bst_float> &preds = io_preds->HostVector();
     const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
     common::ParallelFor(ndata, [&](long j) { // NOLINT(*)
@@ -420,7 +420,7 @@ class GammaRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
@@ -523,7 +523,7 @@ class TweedieRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);