Add check for length of weights. (#4872)
This commit is contained in:
@@ -37,8 +37,12 @@ class HingeObj : public ObjFunction {
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels_.Size();
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
const size_t ndata = preds.Size();
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
out_gpair->Resize(ndata);
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
|
||||
@@ -73,6 +73,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t idx,
|
||||
common::Span<GradientPair> gpair,
|
||||
|
||||
@@ -60,13 +60,17 @@ class RegLossObj : public ObjFunction {
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
|
||||
size_t ndata = preds.Size();
|
||||
size_t const ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
auto device = tparam_->gpu_id;
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
@@ -188,13 +192,17 @@ class PoissonRegression : public ObjFunction {
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
size_t ndata = preds.Size();
|
||||
size_t const ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
auto device = tparam_->gpu_id;
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
bst_float max_delta_step = param_.max_delta_step;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
@@ -282,6 +290,11 @@ class CoxRegression : public ObjFunction {
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
// pre-compute a sum
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
@@ -377,6 +390,10 @@ class GammaRegression : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
@@ -471,6 +488,11 @@ class TweedieRegression : public ObjFunction {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
const float rho = param_.tweedie_variance_power;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
|
||||
Reference in New Issue
Block a user