Initial support for multioutput regression. (#7514)
* Add num_target model parameter, which is configured from input labels.
* Change elementwise metric and indexing for weights.
* Add demo.
* Add tests.
parent: 9ab73f737e
commit: 58a6723eb1
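The first bullet is visible end to end from Python: labels may now carry one column per target, and num_target is recorded in the model. A minimal sketch of the new behavior (illustrative only, not part of the commit; the JSON path mirrors the Learner.MultiTarget C++ test further down):

    import json
    import numpy as np
    import xgboost as xgb

    X = np.random.randn(128, 10)
    y = np.random.randn(128, 3)  # three regression targets, one column each

    reg = xgb.XGBRegressor(tree_method="hist", n_estimators=8)
    reg.fit(X, y)

    # num_target is a model parameter configured from the label shape.
    config = json.loads(reg.get_booster().save_config())
    print(config["learner"]["learner_model_param"]["num_target"])  # "3"

    # Predictions come back with one column per target.
    assert reg.predict(X).shape == (128, 3)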
demo/guide-python/multioutput_regression.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+"""
+A demo for multi-output regression
+==================================
+
+The demo is adapted from scikit-learn:
+
+https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py
+"""
+import numpy as np
+import xgboost as xgb
+import argparse
+from matplotlib import pyplot as plt
+
+
+def plot_predt(y, y_predt, name):
+    s = 25
+    plt.scatter(y[:, 0], y[:, 1], c="navy", s=s,
+                edgecolor="black", label="data")
+    plt.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s,
+                edgecolor="black")
+    plt.xlim([-1, 2])
+    plt.ylim([-1, 2])
+    plt.show()
+
+
+def main(plot_result: bool):
+    """Draw a circle with 2-dim coordinate as target variables."""
+    rng = np.random.RandomState(1994)
+    X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
+    y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
+    y[::5, :] += (0.5 - rng.rand(20, 2))
+    y = y - y.min()
+    y = y / y.max()
+
+    # Train a regressor on it
+    reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
+    reg.fit(X, y, eval_set=[(X, y)])
+
+    y_predt = reg.predict(X)
+    if plot_result:
+        plot_predt(y, y_predt, 'multi')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
+    args = parser.parse_args()
+    main(args.plot == 1)
@@ -77,6 +77,16 @@ class ObjFunction : public Configurable {
    * \brief Return task of this objective.
    */
   virtual struct ObjInfo Task() const = 0;
+  /**
+   * \brief Return number of targets for input matrix. Right now XGBoost supports only
+   * multi-target regression.
+   */
+  virtual uint32_t Targets(MetaInfo const& info) const {
+    if (info.labels.Shape(1) > 1) {
+      LOG(FATAL) << "multioutput is not supported by current objective function";
+    }
+    return 1;
+  }

   /*!
    * \brief Create an objective function according to name.
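The default Targets() rejects matrix labels, so only objectives that override it (the regression losses further down) accept multi-output data; everything else fails at configuration time. A hedged Python sketch of that failure mode (illustrative, written against this development branch; XGBoostError subclasses ValueError, which is what the updated cudf test matches on):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 4)
    y = np.random.randn(64, 2)  # two targets
    dtrain = xgb.DMatrix(X, label=y)

    # reg:squarederror overrides Targets() and accepts 2-dim labels.
    xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=2)

    # multi:softprob keeps the default, so configuration raises.
    try:
        xgb.train({"objective": "multi:softprob", "num_class": 2}, dtrain)
    except xgb.core.XGBoostError as err:
        print(err)  # multioutput is not supported by current objective function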
@@ -21,7 +21,7 @@ CAT_T = "c"
 # meta info that can be a matrix instead of vector.
 # For now it's base_margin for multi-class, but it can be extended to label once we have
 # multi-output.
-_matrix_meta = {"base_margin"}
+_matrix_meta = {"base_margin", "label"}


 def _warn_unused_missing(data, missing):
@@ -160,10 +160,11 @@ void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name,
   CHECK(strm->Read(&is_scalar)) << invalid;
   CHECK(!is_scalar) << invalid << "Expected field " << expected_name
                     << " to be a tensor; got a scalar";
-  std::array<size_t, D> shape;
+  size_t shape[D];
   for (size_t i = 0; i < D; ++i) {
     CHECK(strm->Read(&(shape[i])));
   }
   p_out->Reshape(shape);
   auto& field = p_out->Data()->HostVector();
   CHECK(strm->Read(&field)) << invalid;
 }
@@ -411,6 +412,7 @@ template <int32_t D, typename T>
 void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array{arr_interface};
   if (array.n == 0) {
     p_out->Reshape(array.shape);
     return;
   }
   CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
@@ -737,8 +739,7 @@ void MetaInfo::Validate(int32_t device) const {
     return;
   }
   if (labels.Size() != 0) {
-    CHECK_EQ(labels.Size(), num_row_)
-        << "Size of labels must equal to number of rows.";
+    CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
     check_device(*labels.Data());
     return;
   }
@@ -29,6 +29,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array(arr_interface);
   if (array.n == 0) {
+    p_out->SetDevice(0);
     p_out->Reshape(array.shape);
     return;
   }
   CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
@@ -88,12 +88,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
   /*! \brief the version of XGBoost. */
   uint32_t major_version;
   uint32_t minor_version;

+  uint32_t num_target{1};
   /*! \brief reserved field */
-  int reserved[27];
+  int reserved[26];
   /*! \brief constructor */
   LearnerModelParamLegacy() {
     std::memset(this, 0, sizeof(LearnerModelParamLegacy));
     base_score = 0.5f;
+    num_target = 1;
     major_version = std::get<0>(Version::Self());
     minor_version = std::get<1>(Version::Self());
     static_assert(sizeof(LearnerModelParamLegacy) == 136,
@@ -119,6 +122,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     CHECK(ret.ec == std::errc());
     obj["num_class"] =
         std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

+    ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
+                   static_cast<int64_t>(num_target));
+    obj["num_target"] =
+        std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
+
     return Json(std::move(obj));
   }
   void FromJson(Json const& obj) {
@@ -126,6 +135,11 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     std::map<std::string, std::string> m;
     m["num_feature"] = get<String const>(j_param.at("num_feature"));
     m["num_class"] = get<String const>(j_param.at("num_class"));
+    auto n_targets_it = j_param.find("num_target");
+    if (n_targets_it != j_param.cend()) {
+      m["num_target"] = get<String const>(n_targets_it->second);
+    }

     this->Init(m);
     std::string str = get<String const>(j_param.at("base_score"));
     from_chars(str.c_str(), str.c_str() + str.size(), base_score);
@@ -139,6 +153,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     dmlc::ByteSwap(&x.contain_eval_metrics, sizeof(x.contain_eval_metrics), 1);
     dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
     dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
+    dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
     dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
     return x;
   }
@@ -156,15 +171,24 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe(
         "Number of class option for multi-class classifier. "
         " By default equals 0 and corresponds to binary classifier.");
+    DMLC_DECLARE_FIELD(num_target)
+        .set_default(1)
+        .set_lower_bound(1)
+        .describe("Number of target for multi-target regression.");
   }
 };

 LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
                                      ObjInfo t)
-    : base_score{base_margin},
-      num_feature{user_param.num_feature},
-      num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
-      task{t} {}
+    : base_score{base_margin}, num_feature{user_param.num_feature}, task{t} {
+  auto n_classes = std::max(static_cast<uint32_t>(user_param.num_class), 1u);
+  auto n_targets = user_param.num_target;
+  num_output_group = std::max(n_classes, n_targets);
+  // For version < 1.6, n_targets == 0
+  CHECK(n_classes <= 1 || n_targets <= 1)
+      << "Multi-class multi-output is not yet supported. n_classes:" << n_classes
+      << ", n_targets:" << n_targets;
+}

 struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
   // data split mode, can be row, col, or none.
@@ -325,6 +349,8 @@ class LearnerConfiguration : public Learner {
     args = {cfg_.cbegin(), cfg_.cend()};  // renew
     this->ConfigureObjective(old_tparam, &args);

+    auto task = this->ConfigureTargets();
+
     // Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
     // After 1.0.0 we save the value provided by user and keep it immutable instead. To
     // keep the stability, we initialize it in binary LoadModel instead of configuration.
@@ -339,7 +365,7 @@ class LearnerConfiguration : public Learner {
     // - model is configured second time due to change of parameter
     if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
       learner_model_param_ =
-          LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
+          LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), task);
     }

     this->ConfigureGBM(old_tparam, args);
@@ -586,8 +612,7 @@ class LearnerConfiguration : public Learner {
       CHECK(matrix.first);
       CHECK(!matrix.second.ref.expired());
       const uint64_t num_col = matrix.first->Info().num_col_;
-      CHECK_LE(num_col,
-               static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
+      CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
          << "Unfortunately, XGBoost does not support data matrices with "
          << std::numeric_limits<unsigned>::max() << " features or greater";
       num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
@@ -652,6 +677,31 @@ class LearnerConfiguration : public Learner {
       p_metric->Configure(args);
     }
   }

+  /**
+   * Get number of targets from objective function.
+   */
+  ObjInfo ConfigureTargets() {
+    CHECK(this->obj_);
+    auto const& cache = this->GetPredictionCache()->Container();
+    size_t n_targets = 1;
+    for (auto const& d : cache) {
+      if (n_targets == 1) {
+        n_targets = this->obj_->Targets(d.first->Info());
+      } else {
+        auto t = this->obj_->Targets(d.first->Info());
+        CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
+      }
+    }
+    if (mparam_.num_target != 1) {
+      CHECK(n_targets == 1 || n_targets == mparam_.num_target)
+          << "Inconsistent configuration of num_target. Configuration result from input data:"
+          << n_targets << ", configuration from parameter:" << mparam_.num_target;
+    } else {
+      mparam_.num_target = n_targets;
+    }
+    return this->obj_->Task();
+  }
 };

 std::string const LearnerConfiguration::kEvalMetric {"eval_metric"};  // NOLINT
@@ -784,6 +834,9 @@ class LearnerIO : public LearnerConfiguration {
     if (!DMLC_IO_NO_ENDIAN_SWAP) {
       mparam_ = mparam_.ByteSwap();
     }
+    if (mparam_.num_target == 0) {
+      mparam_.num_target = 1;
+    }
     CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
     CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
@@ -37,20 +37,26 @@ class ElementWiseMetricsReduction {

   PackedReduceResult
   CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
-                   const HostDeviceVector<bst_float> &labels,
+                   linalg::TensorView<float const, 2> labels,
                    const HostDeviceVector<bst_float> &preds,
                    int32_t n_threads) const {
     size_t ndata = labels.Size();
+    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
+    auto h_labels = labels.Values();

-    const auto& h_labels = labels.HostVector();
     const auto& h_weights = weights.HostVector();
     const auto& h_preds = preds.HostVector();

     std::vector<double> score_tloc(n_threads, 0.0);
     std::vector<double> weight_tloc(n_threads, 0.0);

+    // We sum losses over all samples and targets instead of doing it per target, since
+    // the former is more accurate while the latter is only used as an approximation in
+    // the distributed setting. For rmse:
+    //   - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm))       // multi-target
+    //   - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm)  // distributed
     common::ParallelFor(ndata, n_threads, [&](size_t i) {
-      float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
+      float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f;
       auto t_idx = omp_get_thread_num();
       score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
       weight_tloc[t_idx] += wt;
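The aggregation comment is easier to see numerically: summing squared errors over all samples and targets before taking the root gives a different quantity than combining per-target RMSEs. A small numpy sketch (illustrative only, not part of the commit):

    import numpy as np

    rng = np.random.RandomState(0)
    y = rng.rand(100, 3)   # 100 samples, 3 targets
    p = rng.rand(100, 3)   # predictions
    sq = (y - p) ** 2

    # Multi-target: one global sum over all samples and targets, one sqrt.
    multi = np.sqrt(sq.sum() / sq.size)

    # Distributed-style: per-target RMSE first, then combined.
    dist = sum(np.sqrt(sq[:, t].mean()) for t in range(3))

    print(multi, dist)  # the metric here reports the first form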
@@ -66,14 +72,15 @@ class ElementWiseMetricsReduction {

   PackedReduceResult DeviceReduceMetrics(
       const HostDeviceVector<bst_float>& weights,
-      const HostDeviceVector<bst_float>& labels,
+      linalg::TensorView<float const, 2> labels,
       const HostDeviceVector<bst_float>& preds) {
     size_t n_data = preds.Size();
+    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));

     thrust::counting_iterator<size_t> begin(0);
     thrust::counting_iterator<size_t> end = begin + n_data;

-    auto s_label = labels.DeviceSpan();
+    auto s_label = labels.Values();
     auto s_preds = preds.DeviceSpan();
     auto s_weights = weights.DeviceSpan();
@@ -86,7 +93,7 @@ class ElementWiseMetricsReduction {
         thrust::cuda::par(alloc),
         begin, end,
         [=] XGBOOST_DEVICE(size_t idx) {
-          float weight = is_null_weight ? 1.0f : s_weights[idx];
+          float weight = is_null_weight ? 1.0f : s_weights[idx / n_targets];

           float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
           residue *= weight;
@@ -100,26 +107,22 @@ class ElementWiseMetricsReduction {

 #endif  // XGBOOST_USE_CUDA

-  PackedReduceResult Reduce(
-      const GenericParameter &ctx,
-      const HostDeviceVector<bst_float>& weights,
-      const HostDeviceVector<bst_float>& labels,
+  PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector<bst_float>& weights,
+                            linalg::Tensor<float, 2> const& labels,
                             const HostDeviceVector<bst_float>& preds) {
     PackedReduceResult result;

     if (ctx.gpu_id < 0) {
       auto n_threads = ctx.Threads();
-      result = CpuReduceMetrics(weights, labels, preds, n_threads);
+      result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads);
     }
 #if defined(XGBOOST_USE_CUDA)
     else {  // NOLINT
-      device_ = ctx.gpu_id;
-      preds.SetDevice(device_);
-      labels.SetDevice(device_);
-      weights.SetDevice(device_);
+      preds.SetDevice(ctx.gpu_id);
+      weights.SetDevice(ctx.gpu_id);

-      dh::safe_cuda(cudaSetDevice(device_));
-      result = DeviceReduceMetrics(weights, labels, preds);
+      dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
+      result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds);
     }
 #endif  // defined(XGBOOST_USE_CUDA)
     return result;
@@ -128,7 +131,6 @@ class ElementWiseMetricsReduction {
  private:
   EvalRow policy_;
 #if defined(XGBOOST_USE_CUDA)
-  int device_{-1};
 #endif  // defined(XGBOOST_USE_CUDA)
 };
@@ -364,7 +366,7 @@ struct EvalEWiseBase : public Metric {
     CHECK_EQ(preds.Size(), info.labels.Size())
         << "label and prediction size not match, "
         << "hint: use merror or mlogloss for multi-class classification";
-    auto result = reducer_.Reduce(*tparam_, info.weights_, *info.labels.Data(), preds);
+    auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds);

     double dat[2] { result.Residue(), result.Weights() };
@@ -56,6 +56,11 @@ class RegLossObj : public ObjFunction {
     return Loss::Info();
   }

+  uint32_t Targets(MetaInfo const& info) const override {
+    // Multi-target regression.
+    return std::max(static_cast<size_t>(1), info.labels.Shape(1));
+  }

   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector<GradientPair>* out_gpair) override {
@@ -70,7 +75,7 @@ class RegLossObj : public ObjFunction {
     bool is_null_weight = info.weights_.Size() == 0;
     if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), ndata)
+      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
           << "Number of weights should be equal to number of data points.";
     }
     auto scale_pos_weight = param_.scale_pos_weight;
@@ -83,8 +88,10 @@ class RegLossObj : public ObjFunction {
     // for better performance.
     const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
     const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
+    auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));

     common::Transform<>::Init(
-        [block_size, ndata] XGBOOST_DEVICE(
+        [block_size, ndata, n_targets] XGBOOST_DEVICE(
             size_t data_block_idx, common::Span<float> _additional_input,
             common::Span<GradientPair> _out_gpair,
             common::Span<const bst_float> _preds,
|
||||
|
||||
for (size_t idx = begin; idx < end; ++idx) {
|
||||
bst_float p = Loss::PredTransform(preds_ptr[idx]);
|
||||
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx];
|
||||
bst_float w = _is_null_weight ? 1.0f : weights_ptr[idx / n_targets];
|
||||
bst_float label = labels_ptr[idx];
|
||||
if (label == 1.0f) {
|
||||
w *= _scale_pos_weight;
|
||||
|
||||
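The metric and the objective rely on the same memory layout: predictions and labels are flattened row-major, so flat element idx belongs to sample idx / n_targets, which is the correct index into the per-sample weight vector. A quick numpy illustration of the mapping (not part of the commit):

    import numpy as np

    n_samples, n_targets = 4, 3
    labels = np.arange(n_samples * n_targets).reshape(n_samples, n_targets)
    weights = np.array([0.1, 0.2, 0.3, 0.4])  # one weight per sample

    flat = labels.ravel()  # row-major flattening, as in the kernels
    for idx in range(flat.size):
        sample = idx // n_targets  # mirrors weights_ptr[idx / n_targets]
        assert flat[idx] == labels[sample, idx % n_targets]
        w = weights[sample]  # every target of a sample shares that sample's weight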
@@ -92,6 +92,7 @@ TEST(CAPI, ConfigIO) {
     labels[i] = i;
   }
   p_dmat->Info().labels.Data()->HostVector() = labels;
+  p_dmat->Info().labels.Reshape(kRows);

   std::shared_ptr<Learner> learner { Learner::Create(mat) };
@@ -126,6 +127,7 @@ TEST(CAPI, JsonModelIO) {
     labels[i] = i;
   }
   p_dmat->Info().labels.Data()->HostVector() = labels;
+  p_dmat->Info().labels.Reshape(kRows);

   std::shared_ptr<Learner> learner { Learner::Create(mat) };
@@ -9,8 +9,9 @@
 #include <xgboost/linalg.h>

 #include <numeric>
-#include "../../../src/data/array_interface.h"
+#include "../../../src/common/linalg_op.h"
+#include "../../../src/data/array_interface.h"

 namespace xgboost {
 inline void TestMetaInfoStridedData(int32_t device) {
@@ -144,15 +144,26 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
   CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
 }

-xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
+xgboost::bst_float GetMetricEval(xgboost::Metric* metric,
                                  xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
                                  std::vector<xgboost::bst_float> labels,
                                  std::vector<xgboost::bst_float> weights,
                                  std::vector<xgboost::bst_uint> groups) {
+  return GetMultiMetricEval(
+      metric, preds,
+      xgboost::linalg::Tensor<float, 2>{labels.begin(), labels.end(), {labels.size()}, -1}, weights,
+      groups);
+}
+
+double GetMultiMetricEval(xgboost::Metric* metric,
+                          xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
+                          xgboost::linalg::Tensor<float, 2> const& labels,
+                          std::vector<xgboost::bst_float> weights,
+                          std::vector<xgboost::bst_uint> groups) {
   xgboost::MetaInfo info;
-  info.num_row_ = labels.size();
-  info.labels =
-      xgboost::linalg::Tensor<float, 2>{labels.begin(), labels.end(), {labels.size()}, -1};
+  info.num_row_ = labels.Shape(0);
+  info.labels.Reshape(labels.Shape()[0], labels.Shape()[1]);
+  info.labels.Data()->Copy(*labels.Data());
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
@@ -344,13 +355,14 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
     RandomDataGenerator gen(rows_, 1, 0);
     if (!float_label) {
       gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
-      out->Info().labels.Reshape(out->Info().labels.Size());
+      out->Info().labels.Reshape(this->rows_);
       auto& h_labels = out->Info().labels.Data()->HostVector();
       for (auto& v : h_labels) {
        v = static_cast<float>(static_cast<uint32_t>(v));
       }
     } else {
       gen.GenerateDense(out->Info().labels.Data());
+      out->Info().labels.Reshape(this->rows_);
     }
   }
   if (device_ >= 0) {
@@ -91,6 +91,12 @@ xgboost::bst_float GetMetricEval(
     std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
     std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());

+double GetMultiMetricEval(xgboost::Metric* metric,
+                          xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
+                          xgboost::linalg::Tensor<float, 2> const& labels,
+                          std::vector<xgboost::bst_float> weights = {},
+                          std::vector<xgboost::bst_uint> groups = {});
+
 namespace xgboost {
 bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
             std::vector<xgboost::bst_float>::const_iterator _end1,
@@ -40,6 +40,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
 }  // anonymous namespace
 }  // namespace xgboost

+namespace xgboost {
+namespace metric {
+
 TEST(Metric, DeclareUnifiedTest(RMSE)) {
   auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
   xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
@@ -276,3 +279,27 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {

   xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
 }
+
+TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
+  size_t n_samples = 32, n_targets = 8;
+  linalg::Tensor<float, 2> y{{n_samples, n_targets}, GPUIDX};
+  auto &h_y = y.Data()->HostVector();
+  std::iota(h_y.begin(), h_y.end(), 0);
+
+  HostDeviceVector<float> predt(n_samples * n_targets, 0);
+
+  auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
+  std::unique_ptr<Metric> metric{Metric::Create("rmse", &lparam)};
+  metric->Configure({});
+
+  auto loss = GetMultiMetricEval(metric.get(), predt, y);
+  std::vector<float> weights(n_samples, 1);
+  auto loss_w = GetMultiMetricEval(metric.get(), predt, y, weights);
+
+  std::transform(h_y.cbegin(), h_y.cend(), h_y.begin(), [](auto &v) { return v * v; });
+  auto ret = std::sqrt(std::accumulate(h_y.cbegin(), h_y.cend(), 1.0, std::plus<>{}) / h_y.size());
+  ASSERT_FLOAT_EQ(ret, loss);
+  ASSERT_FLOAT_EQ(ret, loss_w);
+}
+}  // namespace metric
+}  // namespace xgboost
@@ -12,9 +12,9 @@
 #include "xgboost/json.h"
 #include "../../src/common/io.h"
-#include "../../src/common/random.h"
+#include "../../src/common/linalg_op.h"

 namespace xgboost {

 TEST(Learner, Basic) {
   using Arg = std::pair<std::string, std::string>;
   auto args = {Arg("tree_method", "exact")};
||||
@ -278,6 +278,7 @@ TEST(Learner, GPUConfiguration) {
|
||||
labels[i] = i;
|
||||
}
|
||||
p_dmat->Info().labels.Data()->HostVector() = labels;
|
||||
p_dmat->Info().labels.Reshape(kRows);
|
||||
{
|
||||
std::unique_ptr<Learner> learner {Learner::Create(mat)};
|
||||
learner->SetParams({Arg{"booster", "gblinear"},
|
||||
@@ -424,4 +425,28 @@ TEST(Learner, FeatureInfo) {
     ASSERT_TRUE(std::equal(out_types.begin(), out_types.end(), types.begin()));
   }
 }
+
+TEST(Learner, MultiTarget) {
+  size_t constexpr kRows{128}, kCols{10}, kTargets{3};
+  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
+  m->Info().labels.Reshape(kRows, kTargets);
+  linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(),
+                                [](auto i, auto) { return i; });
+
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({m})};
+    learner->Configure();
+
+    Json model{Object()};
+    learner->SaveModel(&model);
+    ASSERT_EQ(get<String>(model["learner"]["learner_model_param"]["num_target"]),
+              std::to_string(kTargets));
+  }
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({m})};
+    learner->SetParam("objective", "multi:softprob");
+    // unsupported objective.
+    EXPECT_THROW({ learner->Configure(); }, dmlc::Error);
+  }
+}
 }  // namespace xgboost
@@ -60,8 +60,9 @@ def _test_from_cudf(DMatrixT):
     assert dtrain.feature_names == ['x']
     assert dtrain.feature_types == ['int']

-    with pytest.raises(Exception):
+    with pytest.raises(ValueError, match=r".*multi.*"):
         dtrain = DMatrixT(cd, label=cd)
+        xgb.train({"tree_method": "gpu_hist", "objective": "multi:softprob"}, dtrain)

     # Test when number of elements is less than 8
     X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4],
@@ -50,9 +50,10 @@ def _test_from_cupy(DMatrixT):
     dmatrix_from_cupy(np.int32, DMatrixT, -2)
     dmatrix_from_cupy(np.int64, DMatrixT, -3)

-    with pytest.raises(Exception):
+    with pytest.raises(ValueError):
         X = cp.random.randn(2, 2, dtype="float32")
-        DMatrixT(X, label=X)
+        y = cp.random.randn(2, 2, 3, dtype="float32")
+        DMatrixT(X, label=y)


 def _test_cupy_training(DMatrixT):
@@ -277,7 +277,9 @@ def run_gpu_hist(
     X = to_cp(dataset.X, DMatrixT)
     X = da.from_array(X, chunks=(chunk, dataset.X.shape[1]))
     y = to_cp(dataset.y, DMatrixT)
-    y = da.from_array(y, chunks=(chunk,))
+    y_chunk = chunk if len(dataset.y.shape) == 1 else (chunk, dataset.y.shape[1])
+    y = da.from_array(y, chunks=y_chunk)

     if dataset.w is not None:
         w = to_cp(dataset.w, DMatrixT)
         w = da.from_array(w, chunks=(chunk,))
@@ -52,8 +52,12 @@ def test_boost_from_prediction_gpu_hist():
     X, y = load_digits(return_X_y=True)
     X, y = cp.array(X), cp.array(y)

-    twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
-    twskl.run_boost_from_prediction_multi_clasas(tree_method, X, y, cudf.DataFrame)
+    twskl.run_boost_from_prediction_multi_clasas(
+        xgb.XGBClassifier, tree_method, X, y, None
+    )
+    twskl.run_boost_from_prediction_multi_clasas(
+        xgb.XGBClassifier, tree_method, X, y, cudf.DataFrame
+    )


 def test_num_parallel_tree():
@@ -127,6 +127,14 @@ def test_continuation_demo():
     subprocess.check_call(cmd)


+@pytest.mark.skipif(**tm.no_sklearn())
+@pytest.mark.skipif(**tm.no_matplotlib())
+def test_multioutput_reg() -> None:
+    script = os.path.join(PYTHON_DEMO_DIR, "multioutput_regression.py")
+    cmd = ['python', script, "--plot=0"]
+    subprocess.check_call(cmd)
+
+
 # gpu_acceleration is not tested due to covertype dataset is being too huge.
 # gamma regression is not tested as it requires running a R script first.
 # aft viz is not tested due to ploting is not controled
@@ -1114,9 +1114,9 @@ class TestWithDask:
             return

         chunk = 128
-        X = da.from_array(dataset.X,
-                          chunks=(chunk, dataset.X.shape[1]))
-        y = da.from_array(dataset.y, chunks=(chunk,))
+        y_chunk = chunk if len(dataset.y.shape) == 1 else (chunk, dataset.y.shape[1])
+        X = da.from_array(dataset.X, chunks=(chunk, dataset.X.shape[1]))
+        y = da.from_array(dataset.y, chunks=y_chunk)
         if dataset.w is not None:
             w = da.from_array(dataset.w, chunks=(chunk,))
         else:
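dask requires the chunk specification to match the array's dimensionality, so the 2-dim targets introduced here need a 2-tuple that keeps the target axis whole. A short sketch of the pattern both dask test files now use (illustrative only):

    import dask.array as da
    import numpy as np

    chunk = 128
    y = np.random.randn(1000, 3)  # multi-target labels

    # 1-dim targets chunk along the only axis; 2-dim targets need a 2-tuple.
    y_chunk = chunk if len(y.shape) == 1 else (chunk, y.shape[1])
    dy = da.from_array(y, chunks=y_chunk)
    print(dy.chunks)  # ((128, ..., 104), (3,))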
@@ -1118,10 +1118,10 @@ def run_boost_from_prediction_binary(tree_method, X, y, as_frame: Optional[Callable]


 def run_boost_from_prediction_multi_clasas(
-    tree_method, X, y, as_frame: Optional[Callable]
+    estimator, tree_method, X, y, as_frame: Optional[Callable]
 ):
     # Multi-class
-    model_0 = xgb.XGBClassifier(
+    model_0 = estimator(
         learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
     )
     model_0.fit(X=X, y=y)
@@ -1129,7 +1129,7 @@ def run_boost_from_prediction_multi_clasas(
     if as_frame is not None:
         margin = as_frame(margin)

-    model_1 = xgb.XGBClassifier(
+    model_1 = estimator(
         learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method
     )
     model_1.fit(X=X, y=y, base_margin=margin)
@@ -1137,7 +1137,7 @@ def run_boost_from_prediction_multi_clasas(
         xgb.DMatrix(X, base_margin=margin), output_margin=True
     )

-    model_2 = xgb.XGBClassifier(
+    model_2 = estimator(
         learning_rate=0.3, random_state=0, n_estimators=8, tree_method=tree_method
     )
     model_2.fit(X=X, y=y)
@@ -1152,8 +1152,9 @@ def run_boost_from_prediction_multi_clasas(

 @pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
 def test_boost_from_prediction(tree_method):
-    from sklearn.datasets import load_breast_cancer, load_digits
+    from sklearn.datasets import load_breast_cancer, load_digits, make_regression
     import pandas as pd

     X, y = load_breast_cancer(return_X_y=True)

     run_boost_from_prediction_binary(tree_method, X, y, None)
@@ -1161,8 +1162,13 @@ def test_boost_from_prediction(tree_method):

     X, y = load_digits(return_X_y=True)

-    run_boost_from_prediction_multi_clasas(tree_method, X, y, None)
-    run_boost_from_prediction_multi_clasas(tree_method, X, y, pd.DataFrame)
+    run_boost_from_prediction_multi_clasas(xgb.XGBClassifier, tree_method, X, y, None)
+    run_boost_from_prediction_multi_clasas(
+        xgb.XGBClassifier, tree_method, X, y, pd.DataFrame
+    )
+
+    X, y = make_regression(n_samples=100, n_targets=4)
+    run_boost_from_prediction_multi_clasas(xgb.XGBRegressor, tree_method, X, y, None)


 def test_estimator_type():
@@ -305,26 +305,48 @@ def make_categorical(


 _unweighted_datasets_strategy = strategies.sampled_from(
-    [TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
-     TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),
-     TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
-     TestDataset
-     ("sparse", get_sparse, "reg:squarederror", "rmse"),
-     TestDataset("empty", lambda: (np.empty((0, 100)), np.empty(0)), "reg:squarederror",
-                 "rmse")])
+    [
+        TestDataset("boston", get_boston, "reg:squarederror", "rmse"),
+        TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
+        TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
+        TestDataset(
+            "mtreg",
+            lambda: datasets.make_regression(n_samples=128, n_targets=3),
+            "reg:squarederror",
+            "rmse",
+        ),
+        TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
+        TestDataset(
+            "empty",
+            lambda: (np.empty((0, 100)), np.empty(0)),
+            "reg:squarederror",
+            "rmse",
+        ),
+    ]
+)


 @strategies.composite
 def _dataset_weight_margin(draw):
     data: TestDataset = draw(_unweighted_datasets_strategy)
     if draw(strategies.booleans()):
-        data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
+        data.w = draw(
+            arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))
+        )
     if draw(strategies.booleans()):
         num_class = 1
         if data.objective == "multi:softmax":
             num_class = int(np.max(data.y) + 1)
+        elif data.name == "mtreg":
+            num_class = data.y.shape[1]

         data.margin = draw(
-            arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
+            arrays(
+                np.float64,
+                (data.y.shape[0] * num_class),
+                elements=strategies.floats(0.5, 1.0),
+            )
+        )
+        if num_class != 1:
+            data.margin = data.margin.reshape(data.y.shape[0], num_class)