Add check for length of weights. (#4872)

This commit is contained in:
Jiaming Yuan 2019-12-21 11:30:58 +08:00 committed by GitHub
parent 3d04a8cc97
commit c8bdb652c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 3 deletions

View File

@ -37,8 +37,12 @@ class HingeObj : public ObjFunction {
<< "preds.size=" << preds.Size() << "preds.size=" << preds.Size()
<< ", label.size=" << info.labels_.Size(); << ", label.size=" << info.labels_.Size();
const bool is_null_weight = info.weights_.Size() == 0;
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,

View File

@ -73,6 +73,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t idx, [=] XGBOOST_DEVICE(size_t idx,
common::Span<GradientPair> gpair, common::Span<GradientPair> gpair,

View File

@ -60,13 +60,17 @@ class RegLossObj : public ObjFunction {
CHECK_EQ(preds.Size(), info.labels_.Size()) CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided" << "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
size_t ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto device = tparam_->gpu_id; auto device = tparam_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight; auto scale_pos_weight = param_.scale_pos_weight;
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
@ -188,13 +192,17 @@ class PoissonRegression : public ObjFunction {
HostDeviceVector<GradientPair> *out_gpair) override { HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
size_t ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto device = tparam_->gpu_id; auto device = tparam_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
bst_float max_delta_step = param_.max_delta_step; bst_float max_delta_step = param_.max_delta_step;
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
@ -282,6 +290,11 @@ class CoxRegression : public ObjFunction {
const std::vector<size_t> &label_order = info.LabelAbsSort(); const std::vector<size_t> &label_order = info.LabelAbsSort();
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*) const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
// pre-compute a sum // pre-compute a sum
double exp_p_sum = 0; // we use double because we might need the precision with large datasets double exp_p_sum = 0; // we use double because we might need the precision with large datasets
@ -377,6 +390,10 @@ class GammaRegression : public ObjFunction {
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct, common::Span<int> _label_correct,
@ -471,6 +488,11 @@ class TweedieRegression : public ObjFunction {
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
const float rho = param_.tweedie_variance_power; const float rho = param_.tweedie_variance_power;
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,

View File

@ -25,6 +25,13 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
{0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess
CheckObjFunction(obj,
{1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds
{1.0f, 0.0f}, // labels
{}, // weights
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
{0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess
ASSERT_NO_THROW(obj->DefaultEvalMetric()); ASSERT_NO_THROW(obj->DefaultEvalMetric());
} }