Fix thread safety of softmax prediction. (#7104)

This commit is contained in:
Jiaming Yuan 2021-07-16 02:06:55 +08:00 committed by GitHub
parent 2801d69fb7
commit abec3dbf6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 23 deletions

View File

@ -53,7 +53,7 @@ class ObjFunction : public Configurable {
* \brief transform prediction values, this is only called when Prediction is called
* \param io_preds prediction values, saves to this vector as well
*/
virtual void PredTransform(HostDeviceVector<bst_float>*) {}
virtual void PredTransform(HostDeviceVector<bst_float>*) const {}
/*!
* \brief transform prediction values, this is only called when Eval is called,

View File

@ -58,7 +58,7 @@ class MyLogistic : public ObjFunction {
const char* DefaultEvalMetric() const override {
return "logloss";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
// transform margin value to probability.
std::vector<bst_float> &preds = io_preds->HostVector();
for (auto& pred : preds) {

View File

@ -102,7 +102,7 @@ class AFTObj : public ObjFunction {
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
// Trees give us a prediction in log scale, so exponentiate
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {

View File

@ -68,7 +68,7 @@ class HingeObj : public ObjFunction {
out_gpair, &preds, &info.labels_, &info.weights_);
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;

View File

@ -121,7 +121,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
}
}
}
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
this->Transform(io_preds, output_prob_);
}
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
@ -131,10 +131,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
return "mlogloss";
}
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata);
auto device = io_preds->DeviceIdx();
if (prob) {
@ -148,23 +147,22 @@ class SoftmaxMultiClassObj : public ObjFunction {
.Eval(io_preds);
} else {
io_preds->SetDevice(device);
max_preds_.SetDevice(device);
HostDeviceVector<bst_float> max_preds;
max_preds.SetDevice(device);
max_preds.Resize(ndata);
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<const bst_float> _preds,
[=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
common::Span<bst_float> _max_preds) {
common::Span<const bst_float> point =
_preds.subspan(_idx * nclass, nclass);
_max_preds[_idx] =
common::FindMaxIndex(point.cbegin(),
point.cend()) - point.cbegin();
common::FindMaxIndex(point.cbegin(), point.cend()) -
point.cbegin();
},
common::Range{0, ndata}, device, false)
.Eval(io_preds, &max_preds_);
}
if (!prob) {
io_preds->Resize(max_preds_.Size());
io_preds->Copy(max_preds_);
.Eval(io_preds, &max_preds);
io_preds->Resize(max_preds.Size());
io_preds->Copy(max_preds);
}
}
@ -188,7 +186,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
// parameter
SoftmaxMultiClassParam param_;
// Cache for max_preds
HostDeviceVector<bst_float> max_preds_;
HostDeviceVector<int> label_correct_;
};

View File

@ -109,7 +109,7 @@ class RegLossObj : public ObjFunction {
return Loss::DefaultEvalMetric();
}
void PredTransform(HostDeviceVector<float> *io_preds) override {
void PredTransform(HostDeviceVector<float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
_preds[_idx] = Loss::PredTransform(_preds[_idx]);
@ -233,7 +233,7 @@ class PoissonRegression : public ObjFunction {
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
@ -343,7 +343,7 @@ class CoxRegression : public ObjFunction {
last_exp_p = exp_p;
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
common::ParallelFor(ndata, [&](long j) { // NOLINT(*)
@ -420,7 +420,7 @@ class GammaRegression : public ObjFunction {
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
@ -523,7 +523,7 @@ class TweedieRegression : public ObjFunction {
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);