From 58a6723eb10e18336ada3b9ea32c2fd9a2aee1c6 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 18 Dec 2021 09:28:38 +0800
Subject: [PATCH] Initial support for multioutput regression. (#7514)

* Add num target model parameter, which is configured from input labels.
* Change elementwise metric and indexing for weights.
* Add demo.
* Add tests.
---
 demo/guide-python/multioutput_regression.py |  48 ++++++++++++++
 include/xgboost/objective.h                 |  10 +++
 python-package/xgboost/data.py              |   2 +-
 src/data/data.cc                            |   7 ++-
 src/data/data.cu                            |   1 +
 src/learner.cc                              |  69 ++++++++++++++++++---
 src/metric/elementwise_metric.cu            |  42 +++++++------
 src/objective/regression_obj.cu             |  13 +++-
 tests/cpp/c_api/test_c_api.cc               |   2 +
 tests/cpp/data/test_metainfo.h              |   3 +-
 tests/cpp/helpers.cc                        |  22 +++++--
 tests/cpp/helpers.h                         |   6 ++
 tests/cpp/metric/test_elementwise_metric.cc |  27 ++++++++
 tests/cpp/test_learner.cc                   |  27 +++++++-
 tests/python-gpu/test_from_cudf.py          |   3 +-
 tests/python-gpu/test_from_cupy.py          |   5 +-
 tests/python-gpu/test_gpu_with_dask.py      |   4 +-
 tests/python-gpu/test_gpu_with_sklearn.py   |   8 ++-
 tests/python/test_demos.py                  |   8 +++
 tests/python/test_with_dask.py              |   6 +-
 tests/python/test_with_sklearn.py           |  20 +++---
 tests/python/testing.py                     |  40 +++++++---
 22 files changed, 306 insertions(+), 67 deletions(-)
 create mode 100644 demo/guide-python/multioutput_regression.py

diff --git a/demo/guide-python/multioutput_regression.py b/demo/guide-python/multioutput_regression.py
new file mode 100644
index 000000000..a0d0998e6
--- /dev/null
+++ b/demo/guide-python/multioutput_regression.py
@@ -0,0 +1,48 @@
+"""
+A demo for multi-output regression
+==================================
+
+The demo is adapted from scikit-learn:
+
+https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py
+"""
+import numpy as np
+import xgboost as xgb
+import argparse
+from matplotlib import pyplot as plt
+
+
+def plot_predt(y, y_predt, name):
+    s = 25
+    plt.scatter(y[:, 0], y[:, 1], c="navy", s=s,
+                edgecolor="black", label="data")
+    plt.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s,
+                edgecolor="black")
+    plt.xlim([-1, 2])
+    plt.ylim([-1, 2])
+    plt.show()
+
+
+def main(plot_result: bool):
+    """Draw a circle with 2-dim coordinate as target variables."""
+    rng = np.random.RandomState(1994)
+    X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
+    y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
+    y[::5, :] += (0.5 - rng.rand(20, 2))
+    y = y - y.min()
+    y = y / y.max()
+
+    # Train a regressor on it
+    reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
+    reg.fit(X, y, eval_set=[(X, y)])
+
+    y_predt = reg.predict(X)
+    if plot_result:
+        plot_predt(y, y_predt, 'multi')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
+    args = parser.parse_args()
+    main(args.plot == 1)
diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index 3cf85c41d..40db951b4 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -77,6 +77,16 @@ class ObjFunction : public Configurable {
    * \brief Return task of this objective.
    */
  virtual struct ObjInfo Task() const = 0;
+  /**
+   * \brief Return the number of targets for the input matrix. Right now XGBoost supports
+   * only multi-target regression.
+   */
+  virtual uint32_t Targets(MetaInfo const& info) const {
+    if (info.labels.Shape(1) > 1) {
+      LOG(FATAL) << "multioutput is not supported by the current objective function";
+    }
+    return 1;
+  }
 
   /*!
    * \brief Create an objective function according to name.
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index b59b380b2..bfc4cf2d8 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -21,7 +21,7 @@ CAT_T = "c"
 # meta info that can be a matrix instead of vector.
 # For now it's base_margin for multi-class, but it can be extended to label once we have
 # multi-output.
-_matrix_meta = {"base_margin"}
+_matrix_meta = {"base_margin", "label"}
 
 
 def _warn_unused_missing(data, missing):
diff --git a/src/data/data.cc b/src/data/data.cc
index fa5b388ea..2de5bc8d4 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -160,10 +160,11 @@ void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name,
   CHECK(strm->Read(&is_scalar)) << invalid;
   CHECK(!is_scalar) << invalid << "Expected field " << expected_name
                     << " to be a tensor; got a scalar";
-  std::array<size_t, D> shape;
+  size_t shape[D];
   for (size_t i = 0; i < D; ++i) {
     CHECK(strm->Read(&(shape[i])));
   }
+  p_out->Reshape(shape);
   auto& field = p_out->Data()->HostVector();
   CHECK(strm->Read(&field)) << invalid;
 }
@@ -411,6 +412,7 @@ template <typename T, int32_t D>
 void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array{arr_interface};
   if (array.n == 0) {
+    p_out->Reshape(array.shape);
     return;
   }
   CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
@@ -737,8 +739,7 @@ void MetaInfo::Validate(int32_t device) const {
     return;
   }
   if (labels.Size() != 0) {
-    CHECK_EQ(labels.Size(), num_row_)
-        << "Size of labels must equal to number of rows.";
+    CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal the number of rows.";
     check_device(*labels.Data());
     return;
   }
diff --git a/src/data/data.cu b/src/data/data.cu
index aada91a62..475d70313 100644
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -29,6 +29,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array(arr_interface);
   if (array.n == 0) {
     p_out->SetDevice(0);
+    p_out->Reshape(array.shape);
     return;
   }
   CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
diff --git a/src/learner.cc b/src/learner.cc
index 9ba953fd7..965978412 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -88,12 +88,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
   /*! \brief the version of XGBoost. */
   uint32_t major_version;
   uint32_t minor_version;
+
+  uint32_t num_target{1};
   /*! \brief reserved field */
-  int reserved[27];
+  int reserved[26];
   /*! \brief constructor */
   LearnerModelParamLegacy() {
     std::memset(this, 0, sizeof(LearnerModelParamLegacy));
     base_score = 0.5f;
+    num_target = 1;
     major_version = std::get<0>(Version::Self());
     minor_version = std::get<1>(Version::Self());
     static_assert(sizeof(LearnerModelParamLegacy) == 136,
@@ -119,6 +122,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     CHECK(ret.ec == std::errc());
     obj["num_class"] =
         std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
+
+    ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
+                   static_cast<int64_t>(num_target));
+    obj["num_target"] =
+        std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
+
     return Json(std::move(obj));
   }
   void FromJson(Json const& obj) {
@@ -126,6 +135,11 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     std::map<std::string, std::string> m;
     m["num_feature"] = get<String const>(j_param.at("num_feature"));
     m["num_class"] = get<String const>(j_param.at("num_class"));
+    auto n_targets_it = j_param.find("num_target");
+    if (n_targets_it != j_param.cend()) {
+      m["num_target"] = get<String const>(n_targets_it->second);
+    }
+
     this->Init(m);
     std::string str = get<String const>(j_param.at("base_score"));
     from_chars(str.c_str(), str.c_str() + str.size(), base_score);
@@ -139,6 +153,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     dmlc::ByteSwap(&x.contain_eval_metrics, sizeof(x.contain_eval_metrics), 1);
     dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
     dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
+    dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
     dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
     return x;
   }
@@ -156,15 +171,24 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe(
         "Number of class option for multi-class classifier. "
         " By default equals 0 and corresponds to binary classifier.");
+    DMLC_DECLARE_FIELD(num_target)
+        .set_default(1)
+        .set_lower_bound(1)
+        .describe("Number of targets for multi-target regression.");
   }
 };
 
 LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
                                      ObjInfo t)
-    : base_score{base_margin},
-      num_feature{user_param.num_feature},
-      num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
-      task{t} {}
+    : base_score{base_margin}, num_feature{user_param.num_feature}, task{t} {
+  auto n_classes = std::max(static_cast<uint32_t>(user_param.num_class), 1u);
+  auto n_targets = user_param.num_target;
+  num_output_group = std::max(n_classes, n_targets);
+  // For version < 1.6, n_targets == 0
+  CHECK(n_classes <= 1 || n_targets <= 1)
+      << "Multi-class multi-output is not yet supported. n_classes:" << n_classes
+      << ", n_targets:" << n_targets;
+}
 
 struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
   // data split mode, can be row, col, or none.
@@ -325,6 +349,8 @@ class LearnerConfiguration : public Learner {
     args = {cfg_.cbegin(), cfg_.cend()};  // renew
     this->ConfigureObjective(old_tparam, &args);
 
+    auto task = this->ConfigureTargets();
+
     // Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
     // After 1.0.0 we save the value provided by user and keep it immutable instead. To
     // keep the stability, we initialize it in binary LoadModel instead of configuration.
@@ -339,7 +365,7 @@
     // - model is configured second time due to change of parameter
     if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
       learner_model_param_ =
-          LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
+          LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), task);
     }
 
     this->ConfigureGBM(old_tparam, args);
@@ -586,8 +612,7 @@ class LearnerConfiguration : public Learner {
       CHECK(matrix.first);
       CHECK(!matrix.second.ref.expired());
       const uint64_t num_col = matrix.first->Info().num_col_;
-      CHECK_LE(num_col,
-               static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
+      CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
          << "Unfortunately, XGBoost does not support data matrices with "
          << std::numeric_limits<unsigned>::max() << " features or greater";
       num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
@@ -652,6 +677,31 @@
       p_metric->Configure(args);
     }
   }
+
+  /**
+   * Get the number of targets from the objective function.
+   */
+  ObjInfo ConfigureTargets() {
+    CHECK(this->obj_);
+    auto const& cache = this->GetPredictionCache()->Container();
+    size_t n_targets = 1;
+    for (auto const& d : cache) {
+      if (n_targets == 1) {
+        n_targets = this->obj_->Targets(d.first->Info());
+      } else {
+        auto t = this->obj_->Targets(d.first->Info());
+        CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
+      }
+    }
+    if (mparam_.num_target != 1) {
+      CHECK(n_targets == 1 || n_targets == mparam_.num_target)
+          << "Inconsistent configuration of num_target.  Configured from input data: "
+          << n_targets << ", configured from parameter: " << mparam_.num_target;
+    } else {
+      mparam_.num_target = n_targets;
+    }
+    return this->obj_->Task();
+  }
 };

 std::string const LearnerConfiguration::kEvalMetric {"eval_metric"};  // NOLINT
@@ -784,6 +834,9 @@ class LearnerIO : public LearnerConfiguration {
     if (!DMLC_IO_NO_ENDIAN_SWAP) {
       mparam_ = mparam_.ByteSwap();
     }
+    if (mparam_.num_target == 0) {
+      mparam_.num_target = 1;
+    }
     CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
     CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
index 9dc84da98..abf888e0b 100644
--- a/src/metric/elementwise_metric.cu
+++ b/src/metric/elementwise_metric.cu
@@ -37,20 +37,26 @@ class ElementWiseMetricsReduction {
 
   PackedReduceResult
   CpuReduceMetrics(const HostDeviceVector<bst_float>& weights,
-                   const HostDeviceVector<bst_float>& labels,
+                   linalg::TensorView<float const, 2> labels,
                    const HostDeviceVector<bst_float>& preds,
                    int32_t n_threads) const {
     size_t ndata = labels.Size();
+    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
+    auto h_labels = labels.Values();
 
-    const auto& h_labels = labels.HostVector();
     const auto& h_weights = weights.HostVector();
     const auto& h_preds = preds.HostVector();
 
     std::vector<double> score_tloc(n_threads, 0.0);
     std::vector<double> weight_tloc(n_threads, 0.0);
 
+    // We sum the losses over all samples and targets instead of performing the reduction
+    // for each target, as the former is more accurate while the latter is the approximation
+    // used in the distributed setting. For RMSE:
+    // - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm)) // multi-target
+    // - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm) // distributed
     common::ParallelFor(ndata, n_threads, [&](size_t i) {
-      float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
+      float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f;
       auto t_idx = omp_get_thread_num();
       score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
       weight_tloc[t_idx] += wt;
@@ -66,14 +72,15 @@ class ElementWiseMetricsReduction {
 
   PackedReduceResult DeviceReduceMetrics(
       const HostDeviceVector<bst_float>& weights,
-      const HostDeviceVector<bst_float>& labels,
+      linalg::TensorView<float const, 2> labels,
      const HostDeviceVector<bst_float>& preds) {
     size_t n_data = preds.Size();
+    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
 
     thrust::counting_iterator<size_t> begin(0);
     thrust::counting_iterator<size_t> end = begin + n_data;
 
-    auto s_label = labels.DeviceSpan();
+    auto s_label = labels.Values();
     auto s_preds = preds.DeviceSpan();
     auto s_weights = weights.DeviceSpan();
 
@@ -86,7 +93,7 @@ class ElementWiseMetricsReduction {
         thrust::cuda::par(alloc),
         begin, end,
         [=] XGBOOST_DEVICE(size_t idx) {
-          float weight = is_null_weight ? 1.0f : s_weights[idx];
+          float weight = is_null_weight ? 1.0f : s_weights[idx / n_targets];
 
           float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
           residue *= weight;
@@ -100,26 +107,22 @@ class ElementWiseMetricsReduction {
 
 #endif  // XGBOOST_USE_CUDA
 
-  PackedReduceResult Reduce(
-      const GenericParameter &ctx,
-      const HostDeviceVector<bst_float>& weights,
-      const HostDeviceVector<bst_float>& labels,
-      const HostDeviceVector<bst_float>& preds) {
+  PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector<bst_float>& weights,
+                            linalg::Tensor<float, 2> const& labels,
+                            const HostDeviceVector<bst_float>& preds) {
     PackedReduceResult result;
 
     if (ctx.gpu_id < 0) {
       auto n_threads = ctx.Threads();
-      result = CpuReduceMetrics(weights, labels, preds, n_threads);
+      result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads);
     }
 #if defined(XGBOOST_USE_CUDA)
     else {  // NOLINT
-      device_ = ctx.gpu_id;
-      preds.SetDevice(device_);
-      labels.SetDevice(device_);
-      weights.SetDevice(device_);
+      preds.SetDevice(ctx.gpu_id);
+      weights.SetDevice(ctx.gpu_id);
 
-      dh::safe_cuda(cudaSetDevice(device_));
-      result = DeviceReduceMetrics(weights, labels, preds);
+      dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
+      result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds);
     }
 #endif  // defined(XGBOOST_USE_CUDA)
     return result;
@@ -128,7 +131,6 @@ class ElementWiseMetricsReduction {
  private:
  EvalRow policy_;
 #if defined(XGBOOST_USE_CUDA)
-  int device_{-1};
 #endif  // defined(XGBOOST_USE_CUDA)
 };

@@ -364,7 +366,7 @@ struct EvalEWiseBase : public Metric {
     CHECK_EQ(preds.Size(), info.labels.Size())
         << "label and prediction size not match, "
        << "hint: use merror or mlogloss for multi-class classification";
-    auto result = reducer_.Reduce(*tparam_, info.weights_, *info.labels.Data(), preds);
+    auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds);
 
     double dat[2] { result.Residue(), result.Weights() };
diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index 5dd1a82dd..63a3f881e 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -56,6 +56,11 @@ class RegLossObj : public ObjFunction {
     return Loss::Info();
   }
 
+  uint32_t Targets(MetaInfo const& info) const override {
+    // Multi-target regression.
+    return std::max(static_cast<size_t>(1), info.labels.Shape(1));
+  }
+
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector<GradientPair>* out_gpair) override {
@@ -70,7 +75,7 @@ class RegLossObj : public ObjFunction {
 
     bool is_null_weight = info.weights_.Size() == 0;
     if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), ndata)
+      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
          << "Number of weights should be equal to number of data points.";
     }
     auto scale_pos_weight = param_.scale_pos_weight;
@@ -83,8 +88,10 @@ class RegLossObj : public ObjFunction {
     // for better performance.
     const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
     const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
+    auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));
+
     common::Transform<>::Init(
-        [block_size, ndata] XGBOOST_DEVICE(
+        [block_size, ndata, n_targets] XGBOOST_DEVICE(
            size_t data_block_idx, common::Span<float> _additional_input,
            common::Span<GradientPair> _out_gpair,
            common::Span<const bst_float> _preds,
@@ -101,7 +108,7 @@ class RegLossObj : public ObjFunction {
 
           for (size_t idx = begin; idx < end; ++idx) {
             bst_float p = Loss::PredTransform(preds_ptr[idx]);
-            bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx];
+            bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx / n_targets];
             bst_float label = labels_ptr[idx];
             if (label == 1.0f) {
               w *= _scale_pos_weight;
diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc
index 810d39710..319aafba0 100644
--- a/tests/cpp/c_api/test_c_api.cc
+++ b/tests/cpp/c_api/test_c_api.cc
@@ -92,6 +92,7 @@ TEST(CAPI, ConfigIO) {
     labels[i] = i;
   }
   p_dmat->Info().labels.Data()->HostVector() = labels;
+  p_dmat->Info().labels.Reshape(kRows);
 
   std::shared_ptr<Learner> learner { Learner::Create(mat) };
@@ -126,6 +127,7 @@ TEST(CAPI, JsonModelIO) {
     labels[i] = i;
   }
   p_dmat->Info().labels.Data()->HostVector() = labels;
+  p_dmat->Info().labels.Reshape(kRows);
 
   std::shared_ptr<Learner> learner { Learner::Create(mat) };
diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h
index 04bb2c9e7..f070e6f81 100644
--- a/tests/cpp/data/test_metainfo.h
+++ b/tests/cpp/data/test_metainfo.h
@@ -9,8 +9,9 @@
 #include
 #include
 
-#include "../../../src/data/array_interface.h"
+
+#include "../../../src/common/linalg_op.h"
+#include "../../../src/data/array_interface.h"

 namespace xgboost {
 inline void TestMetaInfoStridedData(int32_t device) {
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index da627cdd1..3a74197d8 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -144,15 +144,26 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
   CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
 }
 
-xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
+xgboost::bst_float GetMetricEval(xgboost::Metric* metric,
                                 xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
                                 std::vector<xgboost::bst_float> labels,
                                 std::vector<xgboost::bst_float> weights,
                                 std::vector<xgboost::bst_uint> groups) {
+  return GetMultiMetricEval(
+      metric, preds,
+      xgboost::linalg::Tensor<float, 2>{labels.begin(), labels.end(), {labels.size()}, -1}, weights,
+      groups);
+}
+
+double GetMultiMetricEval(xgboost::Metric* metric,
+                          xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
+                          xgboost::linalg::Tensor<float, 2> const& labels,
+                          std::vector<xgboost::bst_float> weights,
+                          std::vector<xgboost::bst_uint> groups) {
   xgboost::MetaInfo info;
-  info.num_row_ = labels.size();
-  info.labels =
-      xgboost::linalg::Tensor<float, 1>{labels.begin(), labels.end(), {labels.size()}, -1};
+  info.num_row_ = labels.Shape(0);
+  info.labels.Reshape(labels.Shape()[0], labels.Shape()[1]);
+  info.labels.Data()->Copy(*labels.Data());
 
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
@@ -344,13 +355,14 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
       RandomDataGenerator gen(rows_, 1, 0);
       if (!float_label) {
         gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
-        out->Info().labels.Reshape(out->Info().labels.Size());
+        out->Info().labels.Reshape(this->rows_);
         auto& h_labels = out->Info().labels.Data()->HostVector();
         for (auto& v : h_labels) {
           v = static_cast<float>(static_cast<uint32_t>(v));
         }
       } else {
         gen.GenerateDense(out->Info().labels.Data());
+        out->Info().labels.Reshape(this->rows_);
       }
     }
     if (device_ >= 0) {
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index c89bb3f45..f8b3f5874 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -91,6 +91,12 @@ xgboost::bst_float GetMetricEval(
    std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
    std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());
 
+double GetMultiMetricEval(xgboost::Metric* metric,
+                          xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
+                          xgboost::linalg::Tensor<float, 2> const& labels,
+                          std::vector<xgboost::bst_float> weights = {},
+                          std::vector<xgboost::bst_uint> groups = {});
+
 namespace xgboost {
 bool IsNear(std::vector<bst_float>::const_iterator _beg1,
             std::vector<bst_float>::const_iterator _end1,
diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc
index d5d460a68..514b8753c 100644
--- a/tests/cpp/metric/test_elementwise_metric.cc
+++ b/tests/cpp/metric/test_elementwise_metric.cc
@@ -40,6 +40,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
 }  // anonymous namespace
 }  // namespace xgboost
 
+namespace xgboost {
+namespace metric {
+
 TEST(Metric, DeclareUnifiedTest(RMSE)) {
   auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
   xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
@@ -276,3 +279,27 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
   xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
 }
+
+TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
+  size_t n_samples = 32, n_targets = 8;
+  linalg::Tensor<float, 2> y{{n_samples, n_targets}, GPUIDX};
+  auto &h_y = y.Data()->HostVector();
+  std::iota(h_y.begin(), h_y.end(), 0);
+
+  HostDeviceVector<float> predt(n_samples * n_targets, 0);
+
+  auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
+  std::unique_ptr<Metric> metric{Metric::Create("rmse", &lparam)};
+  metric->Configure({});
+
+  auto loss = GetMultiMetricEval(metric.get(), predt, y);
+  std::vector<float> weights(n_samples, 1);
+  auto loss_w = GetMultiMetricEval(metric.get(), predt, y, weights);
+
+  std::transform(h_y.cbegin(), h_y.cend(), h_y.begin(), [](auto &v) { return v * v; });
+  auto ret = std::sqrt(std::accumulate(h_y.cbegin(), h_y.cend(), 0.0, std::plus<>{}) / h_y.size());
+  ASSERT_FLOAT_EQ(ret, loss);
+  ASSERT_FLOAT_EQ(ret, loss_w);
+}
+}  // namespace metric
+}  // namespace xgboost
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index 23859bc28..6f49e6e8d 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -12,9 +12,9 @@
 #include "xgboost/json.h"
 #include "../../src/common/io.h"
 #include "../../src/common/random.h"
+#include "../../src/common/linalg_op.h"
 
 namespace xgboost {
-
 TEST(Learner, Basic) {
   using Arg = std::pair<std::string, std::string>;
   auto args = {Arg("tree_method", "exact")};
@@ -278,6 +278,7 @@ TEST(Learner, GPUConfiguration) {
     labels[i] = i;
   }
   p_dmat->Info().labels.Data()->HostVector() = labels;
+  p_dmat->Info().labels.Reshape(kRows);
   {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"booster", "gblinear"},
@@ -424,4 +425,28 @@ TEST(Learner, FeatureInfo) {
     ASSERT_TRUE(std::equal(out_types.begin(), out_types.end(), types.begin()));
   }
 }
+
+TEST(Learner, MultiTarget) {
+  size_t constexpr kRows{128}, kCols{10}, kTargets{3};
+  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
+  m->Info().labels.Reshape(kRows, kTargets);
+  linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(),
+                                [](auto i, auto) { return i; });
+
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({m})};
+    learner->Configure();
+
+    Json model{Object()};
+    learner->SaveModel(&model);
+    ASSERT_EQ(get<String>(model["learner"]["learner_model_param"]["num_target"]),
+              std::to_string(kTargets));
+  }
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({m})};
+    learner->SetParam("objective", "multi:softprob");
+    // unsupported objective.
+    EXPECT_THROW({ learner->Configure(); }, dmlc::Error);
+  }
+}
 }  // namespace xgboost
diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py
index defc2b219..dc474f15e 100644
--- a/tests/python-gpu/test_from_cudf.py
+++ b/tests/python-gpu/test_from_cudf.py
@@ -60,8 +60,9 @@ def _test_from_cudf(DMatrixT):
     assert dtrain.feature_names == ['x']
     assert dtrain.feature_types == ['int']
 
-    with pytest.raises(Exception):
+    with pytest.raises(ValueError, match=r".*multi.*"):
         dtrain = DMatrixT(cd, label=cd)
+        xgb.train({"tree_method": "gpu_hist", "objective": "multi:softprob"}, dtrain)
 
     # Test when number of elements is less than 8
     X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4],
diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py
index cbed40777..77fa694e5 100644
--- a/tests/python-gpu/test_from_cupy.py
+++ b/tests/python-gpu/test_from_cupy.py
@@ -50,9 +50,10 @@ def _test_from_cupy(DMatrixT):
     dmatrix_from_cupy(np.int32, DMatrixT, -2)
     dmatrix_from_cupy(np.int64, DMatrixT, -3)
 
-    with pytest.raises(Exception):
+    with pytest.raises(ValueError):
         X = cp.random.randn(2, 2, dtype="float32")
-        DMatrixT(X, label=X)
+        y = cp.random.randn(2, 2, 3, dtype="float32")
+        DMatrixT(X, label=y)
 
 
 def _test_cupy_training(DMatrixT):
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 63ba4f94c..119a02e57 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -277,7 +277,9 @@ def run_gpu_hist(
     X = to_cp(dataset.X, DMatrixT)
     X = da.from_array(X, chunks=(chunk, dataset.X.shape[1]))
     y = to_cp(dataset.y, DMatrixT)
-    y = da.from_array(y, chunks=(chunk,))
+    y_chunk = chunk if len(dataset.y.shape) == 1 else (chunk, dataset.y.shape[1])
+    y = da.from_array(y, chunks=y_chunk)
+
     if dataset.w is not None:
         w = to_cp(dataset.w, DMatrixT)
         w = da.from_array(w, chunks=(chunk,))
diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py
index 17a2ccf83..ae359f2f1 100644
--- a/tests/python-gpu/test_gpu_with_sklearn.py
+++ b/tests/python-gpu/test_gpu_with_sklearn.py
@@ -52,8 +52,12 @@ def test_boost_from_prediction_gpu_hist():
     X, y = load_digits(return_X_y=True)
     X, y = cp.array(X), cp.array(y)
 
-    twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
-    twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, cudf.DataFrame)
+    twskl.run_boost_from_prediction_multi_clasas(
+        xgb.XGBClassifier, tree_method, X, y, None
+    )
+    twskl.run_boost_from_prediction_multi_clasas(
+        xgb.XGBClassifier, tree_method, X, y, cudf.DataFrame
+    )
 
 
 def test_num_parallel_tree():
diff --git a/tests/python/test_demos.py
b/tests/python/test_demos.py index e4d1b804c..7b5b1b19a 100644 --- a/tests/python/test_demos.py +++ b/tests/python/test_demos.py @@ -127,6 +127,14 @@ def test_continuation_demo(): subprocess.check_call(cmd) +@pytest.mark.skipif(**tm.no_sklearn()) +@pytest.mark.skipif(**tm.no_matplotlib()) +def test_multioutput_reg() -> None: + script = os.path.join(PYTHON_DEMO_DIR, "multioutput_regression.py") + cmd = ['python', script, "--plot=0"] + subprocess.check_call(cmd) + + # gpu_acceleration is not tested due to covertype dataset is being too huge. # gamma regression is not tested as it requires running a R script first. # aft viz is not tested due to ploting is not controled diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 68f3d8eff..c03687b16 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1114,9 +1114,9 @@ class TestWithDask: return chunk = 128 - X = da.from_array(dataset.X, - chunks=(chunk, dataset.X.shape[1])) - y = da.from_array(dataset.y, chunks=(chunk,)) + y_chunk = chunk if len(dataset.y.shape) == 1 else (chunk, dataset.y.shape[1]) + X = da.from_array(dataset.X, chunks=(chunk, dataset.X.shape[1])) + y = da.from_array(dataset.y, chunks=y_chunk) if dataset.w is not None: w = da.from_array(dataset.w, chunks=(chunk,)) else: diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 4ab86b7e2..a5c0d8fe2 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1118,10 +1118,10 @@ def run_boost_from_prediction_binary(tree_method, X, y, as_frame: Optional[Calla def run_boost_from_prediction_multi_clasas( - tree_method, X, y, as_frame: Optional[Callable] + estimator, tree_method, X, y, as_frame: Optional[Callable] ): # Multi-class - model_0 = xgb.XGBClassifier( + model_0 = estimator( learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method ) model_0.fit(X=X, y=y) @@ -1129,7 +1129,7 @@ def run_boost_from_prediction_multi_clasas( if as_frame is not None: margin = as_frame(margin) - model_1 = xgb.XGBClassifier( + model_1 = estimator( learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method ) model_1.fit(X=X, y=y, base_margin=margin) @@ -1137,7 +1137,7 @@ def run_boost_from_prediction_multi_clasas( xgb.DMatrix(X, base_margin=margin), output_margin=True ) - model_2 = xgb.XGBClassifier( + model_2 = estimator( learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method ) model_2.fit(X=X, y=y) @@ -1152,8 +1152,9 @@ def run_boost_from_prediction_multi_clasas( @pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"]) def test_boost_from_prediction(tree_method): - from sklearn.datasets import load_breast_cancer, load_digits + from sklearn.datasets import load_breast_cancer, load_digits, make_regression import pandas as pd + X, y = load_breast_cancer(return_X_y=True) run_boost_from_prediction_binary(tree_method, X, y, None) @@ -1161,8 +1162,13 @@ def test_boost_from_prediction(tree_method): X, y = load_digits(return_X_y=True) - run_boost_from_prediction_multi_clasas(tree_method, X, y, None) - run_boost_from_prediction_multi_clasas(tree_method, X, y, pd.DataFrame) + run_boost_from_prediction_multi_clasas(xgb.XGBClassifier, tree_method, X, y, None) + run_boost_from_prediction_multi_clasas( + xgb.XGBClassifier, tree_method, X, y, pd.DataFrame + ) + + X, y = make_regression(n_samples=100, n_targets=4) + run_boost_from_prediction_multi_clasas(xgb.XGBRegressor, tree_method, X, y, None) def test_estimator_type(): diff --git 
a/tests/python/testing.py b/tests/python/testing.py index 2d0886079..cff0e96d5 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -305,26 +305,48 @@ def make_categorical( _unweighted_datasets_strategy = strategies.sampled_from( - [TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'), - TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'), - TestDataset("cancer", get_cancer, "binary:logistic", "logloss"), - TestDataset - ("sparse", get_sparse, "reg:squarederror", "rmse"), - TestDataset("empty", lambda: (np.empty((0, 100)), np.empty(0)), "reg:squarederror", - "rmse")]) + [ + TestDataset("boston", get_boston, "reg:squarederror", "rmse"), + TestDataset("digits", get_digits, "multi:softmax", "mlogloss"), + TestDataset("cancer", get_cancer, "binary:logistic", "logloss"), + TestDataset( + "mtreg", + lambda: datasets.make_regression(n_samples=128, n_targets=3), + "reg:squarederror", + "rmse", + ), + TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"), + TestDataset( + "empty", + lambda: (np.empty((0, 100)), np.empty(0)), + "reg:squarederror", + "rmse", + ), + ] +) @strategies.composite def _dataset_weight_margin(draw): data: TestDataset = draw(_unweighted_datasets_strategy) if draw(strategies.booleans()): - data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))) + data.w = draw( + arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)) + ) if draw(strategies.booleans()): num_class = 1 if data.objective == "multi:softmax": num_class = int(np.max(data.y) + 1) + elif data.name == "mtreg": + num_class = data.y.shape[1] + data.margin = draw( - arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0))) + arrays( + np.float64, + (data.y.shape[0] * num_class), + elements=strategies.floats(0.5, 1.0), + ) + ) if num_class != 1: data.margin = data.margin.reshape(data.y.shape[0], num_class)
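
A usage sketch (an editor's addition, not part of the patch): the smallest end-to-end
example of what this change enables, assuming NumPy and an xgboost build that includes
this patch; the array sizes are arbitrary. As in the demo and the sklearn tests above,
labels are passed as a 2-dim array with one column per target, while weights stay
one-per-sample; the kernels index them as weights_ptr[idx / n_targets].

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(1994)
X = rng.randn(128, 10)
y = rng.randn(128, 3)           # shape (n_samples, n_targets)
w = rng.uniform(0.5, 1.0, 128)  # one weight per sample, not per target

reg = xgb.XGBRegressor(tree_method="hist", n_estimators=8)
reg.fit(X, y, sample_weight=w)
predt = reg.predict(X)          # per-target predictions, as in the demo
assert predt.shape == (128, 3)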
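
The reduction strategy described in the elementwise_metric.cu comment can be checked
with plain NumPy. This sketch mirrors the MultiRMSE test above: all-zero predictions
against labels 0..n-1 with unit weights give rmse = sqrt(sum(label^2) / (n_samples *
n_targets)), which is why the accumulate in that test is seeded with 0.0.

import numpy as np

n_samples, n_targets = 32, 8
y = np.arange(n_samples * n_targets, dtype=np.float64)

# Joint reduction over all samples and targets, as the metric implements.
rmse = np.sqrt((y ** 2).sum() / y.size)

# The per-target alternative used for approximation in the distributed setting:
# sqrt(avg_t0) + sqrt(avg_t1) + ... + sqrt(avg_tm); in general a different number
# from the joint reduction above.
per_target = np.sqrt((y ** 2).reshape(n_samples, n_targets).mean(axis=0)).sum()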