Fix release degradation (#5720)
* fix release degradation, related to 5666 * less resizes Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
This commit is contained in:
parent
251dc8a663
commit
057c762ecd
@ -41,10 +41,11 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
|
||||
template<typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
protected:
|
||||
HostDeviceVector<int> label_correct_;
|
||||
HostDeviceVector<float> additional_input_;
|
||||
|
||||
public:
|
||||
RegLossObj() = default;
|
||||
// 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight
|
||||
RegLossObj(): additional_input_(3) {}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
@ -64,8 +65,7 @@ class RegLossObj : public ObjFunction {
|
||||
size_t const ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
auto device = tparam_->gpu_id;
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
@ -73,35 +73,37 @@ class RegLossObj : public ObjFunction {
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
additional_input_.HostVector().begin()[1] = scale_pos_weight;
|
||||
additional_input_.HostVector().begin()[2] = is_null_weight;
|
||||
|
||||
common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<float> _additional_input,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
const float _scale_pos_weight = _additional_input[1];
|
||||
const bool _is_null_weight = _additional_input[2];
|
||||
|
||||
bst_float p = Loss::PredTransform(_preds[_idx]);
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float w = _is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float label = _labels[_idx];
|
||||
if (label == 1.0f) {
|
||||
w *= scale_pos_weight;
|
||||
w *= _scale_pos_weight;
|
||||
}
|
||||
if (!Loss::CheckLabel(label)) {
|
||||
// If there is an incorrect label, the host code will know.
|
||||
_label_correct[0] = 0;
|
||||
_additional_input[0] = 0;
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
|
||||
Loss::SecondOrderGradient(p, label) * w);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
|
||||
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
&additional_input_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
auto const flag = additional_input_.HostVector().begin()[0];
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user