diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index bb4e7074b..ccb3a723d 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -73,30 +73,47 @@ class RegLossObj : public ObjFunction {
     additional_input_.HostVector().begin()[1] = scale_pos_weight;
     additional_input_.HostVector().begin()[2] = is_null_weight;
 
-    common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
-                                 common::Span<float> _additional_input,
-                                 common::Span<GradientPair> _out_gpair,
-                                 common::Span<const bst_float> _preds,
-                                 common::Span<const bst_float> _labels,
-                                 common::Span<const bst_float> _weights) {
+    const size_t nthreads = tparam_->Threads();
+    bool on_device = device >= 0;
+    // On CPU we run the transformation, each thread processing a contiguous block of data,
+    // for better performance.
+    const size_t n_data_blocks =
+        std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
+    const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
+    common::Transform<>::Init(
+        [block_size, ndata] XGBOOST_DEVICE(
+            size_t data_block_idx, common::Span<float> _additional_input,
+            common::Span<GradientPair> _out_gpair,
+            common::Span<const bst_float> _preds,
+            common::Span<const bst_float> _labels,
+            common::Span<const bst_float> _weights) {
+          const bst_float* preds_ptr = _preds.data();
+          const bst_float* labels_ptr = _labels.data();
+          const bst_float* weights_ptr = _weights.data();
+          GradientPair* out_gpair_ptr = _out_gpair.data();
+          const size_t begin = data_block_idx * block_size;
+          const size_t end = std::min(ndata, begin + block_size);
           const float _scale_pos_weight = _additional_input[1];
           const bool _is_null_weight = _additional_input[2];
 
-          bst_float p = Loss::PredTransform(_preds[_idx]);
-          bst_float w = _is_null_weight ? 1.0f : _weights[_idx];
-          bst_float label = _labels[_idx];
-          if (label == 1.0f) {
-            w *= _scale_pos_weight;
+          for (size_t idx = begin; idx < end; ++idx) {
+            bst_float p = Loss::PredTransform(preds_ptr[idx]);
+            bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx];
+            bst_float label = labels_ptr[idx];
+            if (label == 1.0f) {
+              w *= _scale_pos_weight;
+            }
+            if (!Loss::CheckLabel(label)) {
+              // If there is an incorrect label, the host code will know.
+              _additional_input[0] = 0;
+            }
+            out_gpair_ptr[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
+                                              Loss::SecondOrderGradient(p, label) * w);
           }
-          if (!Loss::CheckLabel(label)) {
-            // If there is an incorrect label, the host code will know.
-            _additional_input[0] = 0;
-          }
-          _out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
-                                          Loss::SecondOrderGradient(p, label) * w);
         },
-        common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
-            &additional_input_, out_gpair, &preds, &info.labels_, &info.weights_);
+        common::Range{0, static_cast<int64_t>(n_data_blocks)}, device)
+        .Eval(&additional_input_, out_gpair, &preds, &info.labels_,
+              &info.weights_);
 
     auto const flag = additional_input_.HostVector().begin()[0];
     if (flag == 0) {
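
For context on the CPU fast path introduced above: the patch splits the ndata gradient computations into n_data_blocks contiguous blocks, sizes each block by ceil division (the `!!(ndata % n_data_blocks)` term adds one element per block when the split is uneven), and hands one block index to each worker, so on CPU every thread walks a contiguous range instead of one element per transform index. The following is a minimal standalone C++ sketch of just that blocking scheme, with hypothetical names and a stand-in body (it records a block id rather than computing gradients); it is not XGBoost code.

#include <algorithm>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t ndata = 10;    // total number of elements (hypothetical)
  const std::size_t nthreads = 4;  // CPU worker count (hypothetical)
  const std::size_t n_blocks = std::max(static_cast<std::size_t>(1), nthreads);
  // Ceil division, as in the patch: guarantees n_blocks * block_size >= ndata.
  const std::size_t block_size = ndata / n_blocks + !!(ndata % n_blocks);

  std::vector<int> out(ndata, -1);
  // One iteration per block index; in the patch each index is dispatched to a thread.
  for (std::size_t block = 0; block < n_blocks; ++block) {
    const std::size_t begin = block * block_size;
    const std::size_t end = std::min(ndata, begin + block_size);  // last block may be short or empty
    for (std::size_t i = begin; i < end; ++i) {
      out[i] = static_cast<int>(block);  // stand-in for the per-element gradient computation
    }
  }
  return 0;
}

With ndata = 10 and n_blocks = 4, block_size is 3 and the ranges are [0,3), [3,6), [6,9), [9,10); clamping end with std::min is what keeps the final, shorter block in bounds, matching the `std::min(ndata, begin + block_size)` in the patch.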