Initial support for multioutput regression. (#7514)
* Add num target model parameter, which is configured from input labels. * Change elementwise metric and indexing for weights. * Add demo. * Add tests.
This commit is contained in:
@@ -56,6 +56,11 @@ class RegLossObj : public ObjFunction {
|
||||
return Loss::Info();
|
||||
}
|
||||
|
||||
uint32_t Targets(MetaInfo const& info) const override {
|
||||
// Multi-target regression.
|
||||
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
@@ -70,7 +75,7 @@ class RegLossObj : public ObjFunction {
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
@@ -83,8 +88,10 @@ class RegLossObj : public ObjFunction {
|
||||
// for better performance.
|
||||
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
|
||||
const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
|
||||
auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));
|
||||
|
||||
common::Transform<>::Init(
|
||||
[block_size, ndata] XGBOOST_DEVICE(
|
||||
[block_size, ndata, n_targets] XGBOOST_DEVICE(
|
||||
size_t data_block_idx, common::Span<float> _additional_input,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
@@ -101,7 +108,7 @@ class RegLossObj : public ObjFunction {
|
||||
|
||||
for (size_t idx = begin; idx < end; ++idx) {
|
||||
bst_float p = Loss::PredTransform(preds_ptr[idx]);
|
||||
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx];
|
||||
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx / n_targets];
|
||||
bst_float label = labels_ptr[idx];
|
||||
if (label == 1.0f) {
|
||||
w *= _scale_pos_weight;
|
||||
|
||||
Reference in New Issue
Block a user