Pass scikit-learn estimator checks for regressor. (#7130)

* Check data shape.
* Check labels.
This commit is contained in:
Jiaming Yuan
2021-08-03 18:58:20 +08:00
committed by GitHub
parent 8ee127469f
commit 8a84be37b8
7 changed files with 103 additions and 39 deletions

View File

@@ -360,13 +360,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) {
return std::isnan(y) || std::isinf(y);
});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
} else if (!std::strcmp(key, "weight")) {
auto& weights = weights_.HostVector();
weights.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
auto valid = std::all_of(weights.cbegin(), weights.cend(),
[](float w) { return w >= 0; });
auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) {
return w < 0 || std::isinf(w) || std::isnan(w);
});
CHECK(valid) << "Weights must be positive values.";
} else if (!std::strcmp(key, "base_margin")) {
auto& base_margin = base_margin_.HostVector();
@@ -419,8 +424,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
bool valid =
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) { return w >= 0; });
std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) { return w < 0; });
CHECK(valid) << "Feature weight must be greater than 0.";
} else {
LOG(FATAL) << "Unknown key for MetaInfo: " << key;

View File

@@ -114,10 +114,11 @@ void CopyQidImpl(ArrayInterface array_interface,
namespace {
// thrust::all_of tries to copy the lambda function; named functors are used for the thrust::none_of checks below as well.
struct AllOfOp {
__device__ bool operator()(float w) {
return w >= 0;
}
// Device-side predicate that flags an invalid label: returns true when `y`
// is NaN or infinite. It is passed to thrust::none_of, so a `true` result for
// any element makes the overall validity check fail.
struct LabelsCheck {
__device__ bool operator()(float y) { return ::isnan(y) || ::isinf(y); }
};
// Device-side predicate that flags an invalid weight: returns true when `w`
// is NaN or infinite (delegated to LabelsCheck) or strictly negative.
// A weight of exactly zero is not flagged.
struct WeightsCheck {
__device__ bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT
};
} // anonymous namespace
@@ -142,11 +143,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
if (key == "label") {
CopyInfoImpl(array_interface, &labels_);
auto ptr = labels_.ConstDevicePointer();
auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(),
LabelsCheck{});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
} else if (key == "weight") {
CopyInfoImpl(array_interface, &weights_);
auto ptr = weights_.ConstDevicePointer();
auto valid =
thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{});
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
WeightsCheck{});
CHECK(valid) << "Weights must be positive values.";
} else if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
@@ -165,9 +170,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "feature_weights") {
CopyInfoImpl(array_interface, &feature_weigths);
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
auto valid = thrust::all_of(
auto valid = thrust::none_of(
thrust::device, d_feature_weights.data(),
d_feature_weights.data() + d_feature_weights.size(), AllOfOp{});
d_feature_weights.data() + d_feature_weights.size(), WeightsCheck{});
CHECK(valid) << "Feature weight must be greater than 0.";
return;
} else {