Fix thread safety of softmax prediction. (#7104)
This commit is contained in:
parent
2801d69fb7
commit
abec3dbf6d
@ -53,7 +53,7 @@ class ObjFunction : public Configurable {
|
|||||||
* \brief transform prediction values, this is only called when Prediction is called
|
* \brief transform prediction values, this is only called when Prediction is called
|
||||||
* \param io_preds prediction values, saves to this vector as well
|
* \param io_preds prediction values, saves to this vector as well
|
||||||
*/
|
*/
|
||||||
virtual void PredTransform(HostDeviceVector<bst_float>*) {}
|
virtual void PredTransform(HostDeviceVector<bst_float>*) const {}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief transform prediction values, this is only called when Eval is called,
|
* \brief transform prediction values, this is only called when Eval is called,
|
||||||
|
|||||||
@ -58,7 +58,7 @@ class MyLogistic : public ObjFunction {
|
|||||||
const char* DefaultEvalMetric() const override {
|
const char* DefaultEvalMetric() const override {
|
||||||
return "logloss";
|
return "logloss";
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
// transform margin value to probability.
|
// transform margin value to probability.
|
||||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||||
for (auto& pred : preds) {
|
for (auto& pred : preds) {
|
||||||
|
|||||||
@ -102,7 +102,7 @@ class AFTObj : public ObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
// Trees give us a prediction in log scale, so exponentiate
|
// Trees give us a prediction in log scale, so exponentiate
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||||
|
|||||||
@ -68,7 +68,7 @@ class HingeObj : public ObjFunction {
|
|||||||
out_gpair, &preds, &info.labels_, &info.weights_);
|
out_gpair, &preds, &info.labels_, &info.weights_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||||
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
|
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
|
||||||
|
|||||||
@ -121,7 +121,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
|
||||||
this->Transform(io_preds, output_prob_);
|
this->Transform(io_preds, output_prob_);
|
||||||
}
|
}
|
||||||
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
|
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||||
@ -131,10 +131,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
return "mlogloss";
|
return "mlogloss";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
|
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
|
||||||
const int nclass = param_.num_class;
|
const int nclass = param_.num_class;
|
||||||
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
|
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
|
||||||
max_preds_.Resize(ndata);
|
|
||||||
|
|
||||||
auto device = io_preds->DeviceIdx();
|
auto device = io_preds->DeviceIdx();
|
||||||
if (prob) {
|
if (prob) {
|
||||||
@ -148,23 +147,22 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
.Eval(io_preds);
|
.Eval(io_preds);
|
||||||
} else {
|
} else {
|
||||||
io_preds->SetDevice(device);
|
io_preds->SetDevice(device);
|
||||||
max_preds_.SetDevice(device);
|
HostDeviceVector<bst_float> max_preds;
|
||||||
|
max_preds.SetDevice(device);
|
||||||
|
max_preds.Resize(ndata);
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[=] XGBOOST_DEVICE(size_t _idx,
|
[=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
|
||||||
common::Span<const bst_float> _preds,
|
|
||||||
common::Span<bst_float> _max_preds) {
|
common::Span<bst_float> _max_preds) {
|
||||||
common::Span<const bst_float> point =
|
common::Span<const bst_float> point =
|
||||||
_preds.subspan(_idx * nclass, nclass);
|
_preds.subspan(_idx * nclass, nclass);
|
||||||
_max_preds[_idx] =
|
_max_preds[_idx] =
|
||||||
common::FindMaxIndex(point.cbegin(),
|
common::FindMaxIndex(point.cbegin(), point.cend()) -
|
||||||
point.cend()) - point.cbegin();
|
point.cbegin();
|
||||||
},
|
},
|
||||||
common::Range{0, ndata}, device, false)
|
common::Range{0, ndata}, device, false)
|
||||||
.Eval(io_preds, &max_preds_);
|
.Eval(io_preds, &max_preds);
|
||||||
}
|
io_preds->Resize(max_preds.Size());
|
||||||
if (!prob) {
|
io_preds->Copy(max_preds);
|
||||||
io_preds->Resize(max_preds_.Size());
|
|
||||||
io_preds->Copy(max_preds_);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,7 +186,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
// parameter
|
// parameter
|
||||||
SoftmaxMultiClassParam param_;
|
SoftmaxMultiClassParam param_;
|
||||||
// Cache for max_preds
|
// Cache for max_preds
|
||||||
HostDeviceVector<bst_float> max_preds_;
|
|
||||||
HostDeviceVector<int> label_correct_;
|
HostDeviceVector<int> label_correct_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class RegLossObj : public ObjFunction {
|
|||||||
return Loss::DefaultEvalMetric();
|
return Loss::DefaultEvalMetric();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredTransform(HostDeviceVector<float> *io_preds) override {
|
void PredTransform(HostDeviceVector<float> *io_preds) const override {
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
|
||||||
_preds[_idx] = Loss::PredTransform(_preds[_idx]);
|
_preds[_idx] = Loss::PredTransform(_preds[_idx]);
|
||||||
@ -233,7 +233,7 @@ class PoissonRegression : public ObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||||
_preds[_idx] = expf(_preds[_idx]);
|
_preds[_idx] = expf(_preds[_idx]);
|
||||||
@ -343,7 +343,7 @@ class CoxRegression : public ObjFunction {
|
|||||||
last_exp_p = exp_p;
|
last_exp_p = exp_p;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||||
common::ParallelFor(ndata, [&](long j) { // NOLINT(*)
|
common::ParallelFor(ndata, [&](long j) { // NOLINT(*)
|
||||||
@ -420,7 +420,7 @@ class GammaRegression : public ObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||||
_preds[_idx] = expf(_preds[_idx]);
|
_preds[_idx] = expf(_preds[_idx]);
|
||||||
@ -523,7 +523,7 @@ class TweedieRegression : public ObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||||
common::Transform<>::Init(
|
common::Transform<>::Init(
|
||||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||||
_preds[_idx] = expf(_preds[_idx]);
|
_preds[_idx] = expf(_preds[_idx]);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user