From b5f52f0b1b3acafc2f1e24ae8f2b6301a71f1761 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 15 Sep 2020 09:03:55 +0800 Subject: [PATCH] Validate weights are positive values. (#6115) --- R-package/tests/testthat/test_dmatrix.R | 4 ++-- demo/guide-python/data_iterator.py | 2 +- src/data/data.cc | 3 +++ src/data/data.cu | 11 +++++++---- tests/python/test_with_modin.py | 2 +- tests/python/test_with_pandas.py | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 91fa5aec0..fc6c4862f 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -64,8 +64,8 @@ test_that("xgb.DMatrix: getinfo & setinfo", { expect_true(setinfo(dtest, 'group', c(50, 50))) expect_error(setinfo(dtest, 'group', test_label)) - # providing character values will give a warning - expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data)))) + # providing character values will give an error + expect_error(setinfo(dtest, 'weight', rep('a', nrow(test_data)))) # any other label should error expect_error(setinfo(dtest, 'asdf', test_label)) diff --git a/demo/guide-python/data_iterator.py b/demo/guide-python/data_iterator.py index 4f4b08c0f..c7300d9f6 100644 --- a/demo/guide-python/data_iterator.py +++ b/demo/guide-python/data_iterator.py @@ -39,7 +39,7 @@ class IterForDMatrixDemo(xgboost.core.DataIter): rng = cupy.random.RandomState(1994) self._data = [rng.randn(self.rows, self.cols)] * BATCHES self._labels = [rng.randn(self.rows)] * BATCHES - self._weights = [rng.randn(self.rows)] * BATCHES + self._weights = [rng.uniform(size=self.rows)] * BATCHES self.it = 0 # set iterator to 0 super().__init__() diff --git a/src/data/data.cc b/src/data/data.cc index ad74008eb..5df5be56c 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -359,6 +359,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t weights.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, std::copy(cast_dptr, cast_dptr + num, weights.begin())); + auto valid = std::all_of(weights.cbegin(), weights.cend(), + [](float w) { return w >= 0; }); + CHECK(valid) << "Weights must be positive values."; } else if (!std::strcmp(key, "base_margin")) { auto& base_margin = base_margin_.HostVector(); base_margin.resize(num); diff --git a/src/data/data.cu b/src/data/data.cu index 152604987..f1eeb01de 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -86,6 +86,10 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { CopyInfoImpl(array_interface, &labels_); } else if (key == "weight") { CopyInfoImpl(array_interface, &weights_); + auto ptr = weights_.ConstDevicePointer(); + auto valid = + thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{}); + CHECK(valid) << "Weights must be positive values."; } else if (key == "base_margin") { CopyInfoImpl(array_interface, &base_margin_); } else if (key == "group") { @@ -100,10 +104,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { } else if (key == "feature_weights") { CopyInfoImpl(array_interface, &feature_weigths); auto d_feature_weights = feature_weigths.ConstDeviceSpan(); - auto valid = - thrust::all_of(thrust::device, d_feature_weights.data(), - d_feature_weights.data() + d_feature_weights.size(), - AllOfOp{}); + auto valid = thrust::all_of( + thrust::device, d_feature_weights.data(), + d_feature_weights.data() + d_feature_weights.size(), AllOfOp{}); CHECK(valid) << "Feature weight must be greater than 0."; return; } else { diff --git a/tests/python/test_with_modin.py b/tests/python/test_with_modin.py index d79672631..05cbaf881 100644 --- a/tests/python/test_with_modin.py +++ b/tests/python/test_with_modin.py @@ -135,7 +135,7 @@ class TestModin(unittest.TestCase): X = np.random.randn(kRows, kCols) y = np.random.randn(kRows) - w = np.random.randn(kRows).astype(np.float32) + w = np.random.uniform(size=kRows).astype(np.float32) w_pd = md.DataFrame(w) data = xgb.DMatrix(X, y, w_pd) diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 04f8c9510..56aa3e9f3 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -151,7 +151,7 @@ class TestPandas(unittest.TestCase): X = np.random.randn(kRows, kCols) y = np.random.randn(kRows) - w = np.random.randn(kRows).astype(np.float32) + w = np.random.uniform(size=kRows).astype(np.float32) w_pd = pd.DataFrame(w) data = xgb.DMatrix(X, y, w_pd)