Fix thread safety of softmax prediction. (#7104)
parent 2801d69fb7
commit abec3dbf6d
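The substantive change is in SoftmaxMultiClassObj (the two multiclass_obj.cu hunks below): PredTransform/Transform become const and the shared member cache max_preds_ is replaced by a buffer allocated per call, so prediction no longer writes to shared state; the remaining hunks propagate the const signature to the other objectives. A minimal standalone sketch of that pattern (plain std:: types, not the actual XGBoost HostDeviceVector/Transform machinery):

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

class SoftmaxLikeObj {
 public:
  // const: may be called from several prediction threads at once.
  void PredTransform(std::vector<float>* io_preds, int nclass) const {
    // Per-call buffer; the pre-fix code wrote into a shared member cache instead.
    std::vector<float> max_preds(io_preds->size() / nclass);
    for (std::size_t i = 0; i < max_preds.size(); ++i) {
      auto row = io_preds->begin() + i * nclass;
      // argmax over one row of class margins
      max_preds[i] = static_cast<float>(std::max_element(row, row + nclass) - row);
    }
    io_preds->resize(max_preds.size());
    std::copy(max_preds.begin(), max_preds.end(), io_preds->begin());
  }
};

int main() {
  SoftmaxLikeObj obj;
  std::vector<float> a{0.1f, 2.0f, 0.3f, 5.0f, 0.2f, 0.1f};
  std::vector<float> b{1.0f, 0.1f, 0.2f, 0.1f, 0.4f, 3.0f};
  std::thread t1([&] { obj.PredTransform(&a, 3); });  // two concurrent callers,
  std::thread t2([&] { obj.PredTransform(&b, 3); });  // no shared mutable state
  t1.join();
  t2.join();
  return 0;
}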
@@ -53,7 +53,7 @@ class ObjFunction : public Configurable {
    * \brief transform prediction values, this is only called when Prediction is called
    * \param io_preds prediction values, saves to this vector as well
    */
-  virtual void PredTransform(HostDeviceVector<bst_float>*) {}
+  virtual void PredTransform(HostDeviceVector<bst_float>*) const {}
 
   /*!
    * \brief transform prediction values, this is only called when Eval is called,
@@ -58,7 +58,7 @@ class MyLogistic : public ObjFunction {
   const char* DefaultEvalMetric() const override {
     return "logloss";
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     // transform margin value to probability.
     std::vector<bst_float> &preds = io_preds->HostVector();
     for (auto& pred : preds) {
@@ -102,7 +102,7 @@ class AFTObj : public ObjFunction {
     }
   }
 
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     // Trees give us a prediction in log scale, so exponentiate
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
@@ -68,7 +68,7 @@ class HingeObj : public ObjFunction {
         out_gpair, &preds, &info.labels_, &info.weights_);
   }
 
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
@@ -121,7 +121,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
   void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
@@ -131,10 +131,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
     return "mlogloss";
   }
 
-  inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
+  inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
     const int nclass = param_.num_class;
     const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
-    max_preds_.Resize(ndata);
 
     auto device = io_preds->DeviceIdx();
     if (prob) {
@@ -148,23 +147,22 @@ class SoftmaxMultiClassObj : public ObjFunction {
           .Eval(io_preds);
     } else {
       io_preds->SetDevice(device);
-      max_preds_.SetDevice(device);
+      HostDeviceVector<bst_float> max_preds;
+      max_preds.SetDevice(device);
+      max_preds.Resize(ndata);
       common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx,
-                             common::Span<const bst_float> _preds,
+          [=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
                              common::Span<bst_float> _max_preds) {
             common::Span<const bst_float> point =
                 _preds.subspan(_idx * nclass, nclass);
             _max_preds[_idx] =
-                common::FindMaxIndex(point.cbegin(),
-                                     point.cend()) - point.cbegin();
+                common::FindMaxIndex(point.cbegin(), point.cend()) -
+                point.cbegin();
           },
           common::Range{0, ndata}, device, false)
-          .Eval(io_preds, &max_preds_);
-    }
-    if (!prob) {
-      io_preds->Resize(max_preds_.Size());
-      io_preds->Copy(max_preds_);
+          .Eval(io_preds, &max_preds);
+      io_preds->Resize(max_preds.Size());
+      io_preds->Copy(max_preds);
     }
   }
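In the hunk above, the per-call HostDeviceVector<bst_float> max_preds plays the same role as the local buffer in the sketch near the top of this page: each call to Transform now owns its argmax buffer instead of writing into the max_preds_ member, which is what lets Transform and PredTransform be const and keeps concurrent prediction calls from racing on shared state.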
@@ -188,7 +186,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
   // parameter
   SoftmaxMultiClassParam param_;
   // Cache for max_preds
-  HostDeviceVector<bst_float> max_preds_;
   HostDeviceVector<int> label_correct_;
 };
 
@@ -109,7 +109,7 @@ class RegLossObj : public ObjFunction {
     return Loss::DefaultEvalMetric();
   }
 
-  void PredTransform(HostDeviceVector<float> *io_preds) override {
+  void PredTransform(HostDeviceVector<float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
           _preds[_idx] = Loss::PredTransform(_preds[_idx]);
@@ -233,7 +233,7 @@ class PoissonRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
@@ -343,7 +343,7 @@ class CoxRegression : public ObjFunction {
       last_exp_p = exp_p;
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     std::vector<bst_float> &preds = io_preds->HostVector();
     const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
     common::ParallelFor(ndata, [&](long j) { // NOLINT(*)
@@ -420,7 +420,7 @@ class GammaRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
@@ -523,7 +523,7 @@ class TweedieRegression : public ObjFunction {
       }
     }
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);