diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index 9652670aa..b889cb81d 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -37,8 +37,12 @@ class HingeObj : public ObjFunction { << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); - const bool is_null_weight = info.weights_.Size() == 0; const size_t ndata = preds.Size(); + const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } out_gpair->Resize(ndata); common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 913e517fc..dce195629 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -73,6 +73,11 @@ class SoftmaxMultiClassObj : public ObjFunction { label_correct_.Fill(1); const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t idx, common::Span gpair, diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 0be272230..cd9b1e4ce 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -60,13 +60,17 @@ class RegLossObj : public ObjFunction { CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided" << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); - size_t ndata = preds.Size(); + size_t const ndata = preds.Size(); out_gpair->Resize(ndata); auto device = tparam_->gpu_id; label_correct_.Resize(1); label_correct_.Fill(1); bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } auto scale_pos_weight = param_.scale_pos_weight; common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, @@ -188,13 +192,17 @@ class PoissonRegression : public ObjFunction { HostDeviceVector *out_gpair) override { CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; - size_t ndata = preds.Size(); + size_t const ndata = preds.Size(); out_gpair->Resize(ndata); auto device = tparam_->gpu_id; label_correct_.Resize(1); label_correct_.Fill(1); bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } bst_float max_delta_step = param_.max_delta_step; common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, @@ -282,6 +290,11 @@ class CoxRegression : public ObjFunction { const std::vector &label_order = info.LabelAbsSort(); const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) + const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } // pre-compute a sum double exp_p_sum = 0; // we use double because we might need the precision with large datasets @@ -377,6 +390,10 @@ class GammaRegression : public ObjFunction { label_correct_.Fill(1); const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, @@ -471,6 +488,11 @@ class TweedieRegression : public ObjFunction { label_correct_.Fill(1); const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } + const float rho = param_.tweedie_variance_power; common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, diff --git a/tests/cpp/objective/test_multiclass_obj.cc b/tests/cpp/objective/test_multiclass_obj.cc index de53dee6a..30e06e977 100644 --- a/tests/cpp/objective/test_multiclass_obj.cc +++ b/tests/cpp/objective/test_multiclass_obj.cc @@ -25,6 +25,13 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess + CheckObjFunction(obj, + {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds + {1.0f, 0.0f}, // labels + {}, // weights + {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad + {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess + ASSERT_NO_THROW(obj->DefaultEvalMetric()); }