Multi-target support for L1 error. (#8652)

- Add matrix support to the median function.
- Iterate through each target for quantile computation.
This commit is contained in:
Jiaming Yuan
2023-01-11 05:51:14 +08:00
committed by GitHub
parent badeff1d74
commit cfa994d57f
19 changed files with 430 additions and 215 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2016-2022 by XGBoost contributors
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#include "helpers.h"
@@ -335,30 +335,30 @@ void RandomDataGenerator::GenerateCSR(
CHECK_EQ(columns->Size(), value->Size());
}
std::shared_ptr<DMatrix>
RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
size_t classes) const {
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
size_t classes) const {
HostDeviceVector<float> data;
HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
this->GenerateCSR(&data, &rptrs, &columns);
data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(),
data.HostPointer(), rows_, data.Size(), cols_);
data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
data.Size(), cols_);
std::shared_ptr<DMatrix> out{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
if (with_label) {
RandomDataGenerator gen(rows_, 1, 0);
RandomDataGenerator gen{rows_, n_targets_, 0.0f};
if (!float_label) {
gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
out->Info().labels.Reshape(this->rows_);
out->Info().labels.Reshape(this->rows_, this->n_targets_);
auto& h_labels = out->Info().labels.Data()->HostVector();
for (auto& v : h_labels) {
v = static_cast<float>(static_cast<uint32_t>(v));
}
} else {
gen.GenerateDense(out->Info().labels.Data());
out->Info().labels.Reshape(this->rows_);
CHECK_EQ(out->Info().labels.Size(), this->rows_ * this->n_targets_);
out->Info().labels.Reshape(this->rows_, this->n_targets_);
}
}
if (device_ >= 0) {