Validate weights are positive values. (#6115)
This commit is contained in:
parent
c6f2b8c841
commit
b5f52f0b1b
@ -64,8 +64,8 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
|
|||||||
expect_true(setinfo(dtest, 'group', c(50, 50)))
|
expect_true(setinfo(dtest, 'group', c(50, 50)))
|
||||||
expect_error(setinfo(dtest, 'group', test_label))
|
expect_error(setinfo(dtest, 'group', test_label))
|
||||||
|
|
||||||
# providing character values will give a warning
|
# providing character values will give an error
|
||||||
expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))
|
expect_error(setinfo(dtest, 'weight', rep('a', nrow(test_data))))
|
||||||
|
|
||||||
# any other label should error
|
# any other label should error
|
||||||
expect_error(setinfo(dtest, 'asdf', test_label))
|
expect_error(setinfo(dtest, 'asdf', test_label))
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class IterForDMatrixDemo(xgboost.core.DataIter):
|
|||||||
rng = cupy.random.RandomState(1994)
|
rng = cupy.random.RandomState(1994)
|
||||||
self._data = [rng.randn(self.rows, self.cols)] * BATCHES
|
self._data = [rng.randn(self.rows, self.cols)] * BATCHES
|
||||||
self._labels = [rng.randn(self.rows)] * BATCHES
|
self._labels = [rng.randn(self.rows)] * BATCHES
|
||||||
self._weights = [rng.randn(self.rows)] * BATCHES
|
self._weights = [rng.uniform(size=self.rows)] * BATCHES
|
||||||
|
|
||||||
self.it = 0 # set iterator to 0
|
self.it = 0 # set iterator to 0
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -359,6 +359,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
weights.resize(num);
|
weights.resize(num);
|
||||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||||
|
auto valid = std::all_of(weights.cbegin(), weights.cend(),
|
||||||
|
[](float w) { return w >= 0; });
|
||||||
|
CHECK(valid) << "Weights must be positive values.";
|
||||||
} else if (!std::strcmp(key, "base_margin")) {
|
} else if (!std::strcmp(key, "base_margin")) {
|
||||||
auto& base_margin = base_margin_.HostVector();
|
auto& base_margin = base_margin_.HostVector();
|
||||||
base_margin.resize(num);
|
base_margin.resize(num);
|
||||||
|
|||||||
@ -86,6 +86,10 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
CopyInfoImpl(array_interface, &labels_);
|
CopyInfoImpl(array_interface, &labels_);
|
||||||
} else if (key == "weight") {
|
} else if (key == "weight") {
|
||||||
CopyInfoImpl(array_interface, &weights_);
|
CopyInfoImpl(array_interface, &weights_);
|
||||||
|
auto ptr = weights_.ConstDevicePointer();
|
||||||
|
auto valid =
|
||||||
|
thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{});
|
||||||
|
CHECK(valid) << "Weights must be positive values.";
|
||||||
} else if (key == "base_margin") {
|
} else if (key == "base_margin") {
|
||||||
CopyInfoImpl(array_interface, &base_margin_);
|
CopyInfoImpl(array_interface, &base_margin_);
|
||||||
} else if (key == "group") {
|
} else if (key == "group") {
|
||||||
@ -100,10 +104,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
} else if (key == "feature_weights") {
|
} else if (key == "feature_weights") {
|
||||||
CopyInfoImpl(array_interface, &feature_weigths);
|
CopyInfoImpl(array_interface, &feature_weigths);
|
||||||
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
||||||
auto valid =
|
auto valid = thrust::all_of(
|
||||||
thrust::all_of(thrust::device, d_feature_weights.data(),
|
thrust::device, d_feature_weights.data(),
|
||||||
d_feature_weights.data() + d_feature_weights.size(),
|
d_feature_weights.data() + d_feature_weights.size(), AllOfOp{});
|
||||||
AllOfOp{});
|
|
||||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -135,7 +135,7 @@ class TestModin(unittest.TestCase):
|
|||||||
|
|
||||||
X = np.random.randn(kRows, kCols)
|
X = np.random.randn(kRows, kCols)
|
||||||
y = np.random.randn(kRows)
|
y = np.random.randn(kRows)
|
||||||
w = np.random.randn(kRows).astype(np.float32)
|
w = np.random.uniform(size=kRows).astype(np.float32)
|
||||||
w_pd = md.DataFrame(w)
|
w_pd = md.DataFrame(w)
|
||||||
data = xgb.DMatrix(X, y, w_pd)
|
data = xgb.DMatrix(X, y, w_pd)
|
||||||
|
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class TestPandas(unittest.TestCase):
|
|||||||
|
|
||||||
X = np.random.randn(kRows, kCols)
|
X = np.random.randn(kRows, kCols)
|
||||||
y = np.random.randn(kRows)
|
y = np.random.randn(kRows)
|
||||||
w = np.random.randn(kRows).astype(np.float32)
|
w = np.random.uniform(size=kRows).astype(np.float32)
|
||||||
w_pd = pd.DataFrame(w)
|
w_pd = pd.DataFrame(w)
|
||||||
data = xgb.DMatrix(X, y, w_pd)
|
data = xgb.DMatrix(X, y, w_pd)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user