Pass scikit learn estimator checks for regressor. (#7130)
* Check data shape. * Check labels.
This commit is contained in:
@@ -360,13 +360,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
labels.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||
auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) {
|
||||
return std::isnan(y) || std::isinf(y);
|
||||
});
|
||||
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
||||
} else if (!std::strcmp(key, "weight")) {
|
||||
auto& weights = weights_.HostVector();
|
||||
weights.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||
auto valid = std::all_of(weights.cbegin(), weights.cend(),
|
||||
[](float w) { return w >= 0; });
|
||||
auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) {
|
||||
return w < 0 || std::isinf(w) || std::isnan(w);
|
||||
});
|
||||
CHECK(valid) << "Weights must be positive values.";
|
||||
} else if (!std::strcmp(key, "base_margin")) {
|
||||
auto& base_margin = base_margin_.HostVector();
|
||||
@@ -419,8 +424,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
|
||||
bool valid =
|
||||
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
||||
[](float w) { return w >= 0; });
|
||||
std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
||||
[](float w) { return w < 0; });
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
|
||||
|
||||
Reference in New Issue
Block a user