Add check for length of weights. (#4872)
This commit is contained in:
parent
3d04a8cc97
commit
c8bdb652c4
@ -37,8 +37,12 @@ class HingeObj : public ObjFunction {
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels_.Size();
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
const size_t ndata = preds.Size();
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
out_gpair->Resize(ndata);
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
|
||||
@ -73,6 +73,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t idx,
|
||||
common::Span<GradientPair> gpair,
|
||||
|
||||
@ -60,13 +60,17 @@ class RegLossObj : public ObjFunction {
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
|
||||
size_t ndata = preds.Size();
|
||||
size_t const ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
auto device = tparam_->gpu_id;
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
@ -188,13 +192,17 @@ class PoissonRegression : public ObjFunction {
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
size_t ndata = preds.Size();
|
||||
size_t const ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
auto device = tparam_->gpu_id;
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
bst_float max_delta_step = param_.max_delta_step;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
@ -282,6 +290,11 @@ class CoxRegression : public ObjFunction {
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
// pre-compute a sum
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
@ -377,6 +390,10 @@ class GammaRegression : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
@ -471,6 +488,11 @@ class TweedieRegression : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
const float rho = param_.tweedie_variance_power;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
|
||||
@ -25,6 +25,13 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
|
||||
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
|
||||
{0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess
|
||||
|
||||
CheckObjFunction(obj,
|
||||
{1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds
|
||||
{1.0f, 0.0f}, // labels
|
||||
{}, // weights
|
||||
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
|
||||
{0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user