Initial support for multioutput regression. (#7514)

* Add num target model parameter, which is configured from input labels.
* Change elementwise metric and indexing for weights.
* Add demo.
* Add tests.
This commit is contained in:
Jiaming Yuan
2021-12-18 09:28:38 +08:00
committed by GitHub
parent 9ab73f737e
commit 58a6723eb1
22 changed files with 306 additions and 67 deletions

View File

@@ -56,6 +56,11 @@ class RegLossObj : public ObjFunction {
return Loss::Info();
}
uint32_t Targets(MetaInfo const& info) const override {
// Multi-target regression.
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair>* out_gpair) override {
@@ -70,7 +75,7 @@ class RegLossObj : public ObjFunction {
bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight;
@@ -83,8 +88,10 @@ class RegLossObj : public ObjFunction {
// for better performance.
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));
common::Transform<>::Init(
[block_size, ndata] XGBOOST_DEVICE(
[block_size, ndata, n_targets] XGBOOST_DEVICE(
size_t data_block_idx, common::Span<float> _additional_input,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
@@ -101,7 +108,7 @@ class RegLossObj : public ObjFunction {
for (size_t idx = begin; idx < end; ++idx) {
bst_float p = Loss::PredTransform(preds_ptr[idx]);
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx];
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx / n_targets];
bst_float label = labels_ptr[idx];
if (label == 1.0f) {
w *= _scale_pos_weight;