Pass scikit learn estimator checks for regressor. (#7130)
* Check data shape. * Check labels.
This commit is contained in:
parent
8ee127469f
commit
8a84be37b8
@ -584,8 +584,6 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
`gpu_predictor` and pandas input are required.
|
`gpu_predictor` and pandas input are required.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(data, list):
|
|
||||||
raise TypeError("Input data can not be a list.")
|
|
||||||
if group is not None and qid is not None:
|
if group is not None and qid is not None:
|
||||||
raise ValueError("Either one of `group` or `qid` should be None.")
|
raise ValueError("Either one of `group` or `qid` should be None.")
|
||||||
|
|
||||||
@ -2005,6 +2003,10 @@ class Booster(object):
|
|||||||
p_handle = ctypes.c_void_p()
|
p_handle = ctypes.c_void_p()
|
||||||
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
||||||
if validate_features:
|
if validate_features:
|
||||||
|
if not hasattr(data, "shape"):
|
||||||
|
raise TypeError(
|
||||||
|
"`shape` attribute is required when `validate_features` is True."
|
||||||
|
)
|
||||||
if len(data.shape) != 1 and self.num_features() != data.shape[1]:
|
if len(data.shape) != 1 and self.num_features() != data.shape[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Feature shape mismatch, expected: {self.num_features()}, "
|
f"Feature shape mismatch, expected: {self.num_features()}, "
|
||||||
|
|||||||
@ -32,6 +32,11 @@ def _check_complex(data):
|
|||||||
raise ValueError('Complex data not supported')
|
raise ValueError('Complex data not supported')
|
||||||
|
|
||||||
|
|
||||||
|
def _check_data_shape(data: Any) -> None:
|
||||||
|
if hasattr(data, "shape") and len(data.shape) != 2:
|
||||||
|
raise ValueError("Please reshape the input data into 2-dimensional matrix.")
|
||||||
|
|
||||||
|
|
||||||
def _is_scipy_csr(data):
|
def _is_scipy_csr(data):
|
||||||
try:
|
try:
|
||||||
import scipy
|
import scipy
|
||||||
@ -524,16 +529,18 @@ def _is_list(data):
|
|||||||
return isinstance(data, list)
|
return isinstance(data, list)
|
||||||
|
|
||||||
|
|
||||||
def _from_list(data, missing, feature_names, feature_types):
|
def _from_list(data, missing, n_threads, feature_names, feature_types):
|
||||||
raise TypeError('List input data is not supported for data')
|
array = np.array(data)
|
||||||
|
_check_data_shape(data)
|
||||||
|
return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
def _is_tuple(data):
|
def _is_tuple(data):
|
||||||
return isinstance(data, tuple)
|
return isinstance(data, tuple)
|
||||||
|
|
||||||
|
|
||||||
def _from_tuple(data, missing, feature_names, feature_types):
|
def _from_tuple(data, missing, n_threads, feature_names, feature_types):
|
||||||
return _from_list(data, missing, feature_names, feature_types)
|
return _from_list(data, missing, n_threads, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
def _is_iter(data):
|
def _is_iter(data):
|
||||||
@ -566,6 +573,8 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
feature_names, feature_types,
|
feature_names, feature_types,
|
||||||
enable_categorical=False):
|
enable_categorical=False):
|
||||||
'''Dispatch data for DMatrix.'''
|
'''Dispatch data for DMatrix.'''
|
||||||
|
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||||
|
_check_data_shape(data)
|
||||||
if _is_scipy_csr(data):
|
if _is_scipy_csr(data):
|
||||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||||
if _is_scipy_csc(data):
|
if _is_scipy_csc(data):
|
||||||
@ -578,9 +587,9 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
if _is_uri(data):
|
if _is_uri(data):
|
||||||
return _from_uri(data, missing, feature_names, feature_types)
|
return _from_uri(data, missing, feature_names, feature_types)
|
||||||
if _is_list(data):
|
if _is_list(data):
|
||||||
return _from_list(data, missing, feature_names, feature_types)
|
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||||
if _is_tuple(data):
|
if _is_tuple(data):
|
||||||
return _from_tuple(data, missing, feature_names, feature_types)
|
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
return _from_pandas_df(data, enable_categorical, missing, threads,
|
return _from_pandas_df(data, enable_categorical, missing, threads,
|
||||||
feature_names, feature_types)
|
feature_names, feature_types)
|
||||||
@ -612,11 +621,12 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
return _from_pandas_series(data, missing, threads, feature_names,
|
return _from_pandas_series(data, missing, threads, feature_names,
|
||||||
feature_types)
|
feature_types)
|
||||||
if _has_array_protocol(data):
|
if _has_array_protocol(data):
|
||||||
pass
|
array = np.asarray(data)
|
||||||
|
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
||||||
|
|
||||||
converted = _convert_unknown_data(data)
|
converted = _convert_unknown_data(data)
|
||||||
if converted:
|
if converted is not None:
|
||||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
return _from_scipy_csr(converted, missing, threads, feature_names, feature_types)
|
||||||
|
|
||||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||||
|
|
||||||
@ -630,11 +640,12 @@ def _to_data_type(dtype: str, name: str):
|
|||||||
return dtype_map[dtype]
|
return dtype_map[dtype]
|
||||||
|
|
||||||
|
|
||||||
def _validate_meta_shape(data):
|
def _validate_meta_shape(data, name: str) -> None:
|
||||||
if hasattr(data, 'shape'):
|
if hasattr(data, "shape"):
|
||||||
assert len(data.shape) == 1 or (
|
if len(data.shape) > 2 or (
|
||||||
len(data.shape) == 2 and
|
len(data.shape) == 2 and (data.shape[1] != 0 and data.shape[1] != 1)
|
||||||
(data.shape[1] == 0 or data.shape[1] == 1))
|
):
|
||||||
|
raise ValueError(f"Invalid shape: {data.shape} for {name}")
|
||||||
|
|
||||||
|
|
||||||
def _meta_from_numpy(data, field, dtype, handle):
|
def _meta_from_numpy(data, field, dtype, handle):
|
||||||
@ -702,7 +713,7 @@ def _meta_from_dt(data, field, dtype, handle):
|
|||||||
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
||||||
'''Dispatch for meta info.'''
|
'''Dispatch for meta info.'''
|
||||||
handle = matrix.handle
|
handle = matrix.handle
|
||||||
_validate_meta_shape(data)
|
_validate_meta_shape(data, name)
|
||||||
if data is None:
|
if data is None:
|
||||||
return
|
return
|
||||||
if _is_list(data):
|
if _is_list(data):
|
||||||
@ -751,7 +762,9 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
|||||||
_meta_from_numpy(data, name, dtype, handle)
|
_meta_from_numpy(data, name, dtype, handle)
|
||||||
return
|
return
|
||||||
if _has_array_protocol(data):
|
if _has_array_protocol(data):
|
||||||
pass
|
array = np.asarray(data)
|
||||||
|
_meta_from_numpy(array, name, dtype, handle)
|
||||||
|
return
|
||||||
raise TypeError('Unsupported type for ' + name, str(type(data)))
|
raise TypeError('Unsupported type for ' + name, str(type(data)))
|
||||||
|
|
||||||
|
|
||||||
@ -802,6 +815,8 @@ def _proxy_transform(data, feature_names, feature_types, enable_categorical):
|
|||||||
|
|
||||||
def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -> None:
|
def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -> None:
|
||||||
"""Dispatch for DeviceQuantileDMatrix."""
|
"""Dispatch for DeviceQuantileDMatrix."""
|
||||||
|
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||||
|
_check_data_shape(data)
|
||||||
if _is_cudf_df(data):
|
if _is_cudf_df(data):
|
||||||
proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212
|
proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212
|
||||||
return
|
return
|
||||||
|
|||||||
@ -419,7 +419,6 @@ class XGBModel(XGBModelBase):
|
|||||||
self.base_score = base_score
|
self.base_score = base_score
|
||||||
self.missing = missing
|
self.missing = missing
|
||||||
self.num_parallel_tree = num_parallel_tree
|
self.num_parallel_tree = num_parallel_tree
|
||||||
self.kwargs = kwargs
|
|
||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.monotone_constraints = monotone_constraints
|
self.monotone_constraints = monotone_constraints
|
||||||
@ -429,6 +428,8 @@ class XGBModel(XGBModelBase):
|
|||||||
self.validate_parameters = validate_parameters
|
self.validate_parameters = validate_parameters
|
||||||
self.predictor = predictor
|
self.predictor = predictor
|
||||||
self.enable_categorical = enable_categorical
|
self.enable_categorical = enable_categorical
|
||||||
|
if kwargs:
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def _more_tags(self) -> Dict[str, bool]:
|
def _more_tags(self) -> Dict[str, bool]:
|
||||||
'''Tags used for scikit-learn data validation.'''
|
'''Tags used for scikit-learn data validation.'''
|
||||||
@ -469,6 +470,8 @@ class XGBModel(XGBModelBase):
|
|||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
else:
|
else:
|
||||||
|
if not hasattr(self, "kwargs"):
|
||||||
|
self.kwargs = {}
|
||||||
self.kwargs[key] = value
|
self.kwargs[key] = value
|
||||||
|
|
||||||
if hasattr(self, '_Booster'):
|
if hasattr(self, '_Booster'):
|
||||||
@ -491,7 +494,7 @@ class XGBModel(XGBModelBase):
|
|||||||
cp.__class__ = cp.__class__.__bases__[0]
|
cp.__class__ = cp.__class__.__bases__[0]
|
||||||
params.update(cp.__class__.get_params(cp, deep))
|
params.update(cp.__class__.get_params(cp, deep))
|
||||||
# if kwargs is a dict, update params accordingly
|
# if kwargs is a dict, update params accordingly
|
||||||
if isinstance(self.kwargs, dict):
|
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
|
||||||
params.update(self.kwargs)
|
params.update(self.kwargs)
|
||||||
if isinstance(params['random_state'], np.random.RandomState):
|
if isinstance(params['random_state'], np.random.RandomState):
|
||||||
params['random_state'] = params['random_state'].randint(
|
params['random_state'] = params['random_state'].randint(
|
||||||
@ -745,7 +748,6 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
evals_result: TrainingCallback.EvalsLog = {}
|
evals_result: TrainingCallback.EvalsLog = {}
|
||||||
|
|
||||||
train_dmatrix, evals = _wrap_evaluation_matrices(
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
X=X,
|
X=X,
|
||||||
@ -1169,7 +1171,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
):
|
):
|
||||||
raise ValueError(label_encoding_check_error)
|
raise ValueError(label_encoding_check_error)
|
||||||
else:
|
else:
|
||||||
self.classes_ = np.unique(y)
|
self.classes_ = np.unique(np.asarray(y))
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
if not self.use_label_encoder and (
|
if not self.use_label_encoder and (
|
||||||
not np.array_equal(self.classes_, np.arange(self.n_classes_))
|
not np.array_equal(self.classes_, np.arange(self.n_classes_))
|
||||||
@ -1206,11 +1208,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
label_transform = lambda x: x
|
label_transform = lambda x: x
|
||||||
|
|
||||||
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
if len(X.shape) != 2:
|
|
||||||
# Simply raise an error here since there might be many
|
|
||||||
# different ways of reshaping
|
|
||||||
raise ValueError("Please reshape the input data X into 2-dimensional matrix.")
|
|
||||||
|
|
||||||
train_dmatrix, evals = _wrap_evaluation_matrices(
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
X=X,
|
X=X,
|
||||||
|
|||||||
@ -360,13 +360,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
labels.resize(num);
|
labels.resize(num);
|
||||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||||
|
auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) {
|
||||||
|
return std::isnan(y) || std::isinf(y);
|
||||||
|
});
|
||||||
|
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
||||||
} else if (!std::strcmp(key, "weight")) {
|
} else if (!std::strcmp(key, "weight")) {
|
||||||
auto& weights = weights_.HostVector();
|
auto& weights = weights_.HostVector();
|
||||||
weights.resize(num);
|
weights.resize(num);
|
||||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||||
auto valid = std::all_of(weights.cbegin(), weights.cend(),
|
auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) {
|
||||||
[](float w) { return w >= 0; });
|
return w < 0 || std::isinf(w) || std::isnan(w);
|
||||||
|
});
|
||||||
CHECK(valid) << "Weights must be positive values.";
|
CHECK(valid) << "Weights must be positive values.";
|
||||||
} else if (!std::strcmp(key, "base_margin")) {
|
} else if (!std::strcmp(key, "base_margin")) {
|
||||||
auto& base_margin = base_margin_.HostVector();
|
auto& base_margin = base_margin_.HostVector();
|
||||||
@ -419,8 +424,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
dtype, dptr, cast_dptr,
|
dtype, dptr, cast_dptr,
|
||||||
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
|
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
|
||||||
bool valid =
|
bool valid =
|
||||||
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
||||||
[](float w) { return w >= 0; });
|
[](float w) { return w < 0; });
|
||||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
|
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
|
||||||
|
|||||||
@ -114,10 +114,11 @@ void CopyQidImpl(ArrayInterface array_interface,
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// thrust::all_of tries to copy lambda function.
|
// thrust::all_of tries to copy lambda function.
|
||||||
struct AllOfOp {
|
struct LabelsCheck {
|
||||||
__device__ bool operator()(float w) {
|
__device__ bool operator()(float y) { return ::isnan(y) || ::isinf(y); }
|
||||||
return w >= 0;
|
};
|
||||||
}
|
struct WeightsCheck {
|
||||||
|
__device__ bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
@ -142,11 +143,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
|
|
||||||
if (key == "label") {
|
if (key == "label") {
|
||||||
CopyInfoImpl(array_interface, &labels_);
|
CopyInfoImpl(array_interface, &labels_);
|
||||||
|
auto ptr = labels_.ConstDevicePointer();
|
||||||
|
auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(),
|
||||||
|
LabelsCheck{});
|
||||||
|
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
||||||
} else if (key == "weight") {
|
} else if (key == "weight") {
|
||||||
CopyInfoImpl(array_interface, &weights_);
|
CopyInfoImpl(array_interface, &weights_);
|
||||||
auto ptr = weights_.ConstDevicePointer();
|
auto ptr = weights_.ConstDevicePointer();
|
||||||
auto valid =
|
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
|
||||||
thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{});
|
WeightsCheck{});
|
||||||
CHECK(valid) << "Weights must be positive values.";
|
CHECK(valid) << "Weights must be positive values.";
|
||||||
} else if (key == "base_margin") {
|
} else if (key == "base_margin") {
|
||||||
CopyInfoImpl(array_interface, &base_margin_);
|
CopyInfoImpl(array_interface, &base_margin_);
|
||||||
@ -165,9 +170,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
} else if (key == "feature_weights") {
|
} else if (key == "feature_weights") {
|
||||||
CopyInfoImpl(array_interface, &feature_weigths);
|
CopyInfoImpl(array_interface, &feature_weigths);
|
||||||
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
||||||
auto valid = thrust::all_of(
|
auto valid = thrust::none_of(
|
||||||
thrust::device, d_feature_weights.data(),
|
thrust::device, d_feature_weights.data(),
|
||||||
d_feature_weights.data() + d_feature_weights.size(), AllOfOp{});
|
d_feature_weights.data() + d_feature_weights.size(), WeightsCheck{});
|
||||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -330,3 +330,12 @@ class TestDMatrix:
|
|||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
d = Data()
|
d = Data()
|
||||||
xgb.DMatrix(d)
|
xgb.DMatrix(d)
|
||||||
|
|
||||||
|
from scipy import sparse
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
X = rng.rand(10, 10)
|
||||||
|
y = rng.rand(10)
|
||||||
|
X = sparse.dok_matrix(X)
|
||||||
|
Xy = xgb.DMatrix(X, y)
|
||||||
|
assert Xy.num_row() == 10
|
||||||
|
assert Xy.num_col() == 10
|
||||||
|
|||||||
@ -13,6 +13,8 @@ rng = np.random.RandomState(1994)
|
|||||||
|
|
||||||
pytestmark = pytest.mark.skipif(**tm.no_sklearn())
|
pytestmark = pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
|
||||||
|
from sklearn.utils.estimator_checks import parametrize_with_checks
|
||||||
|
|
||||||
|
|
||||||
class TemporaryDirectory(object):
|
class TemporaryDirectory(object):
|
||||||
"""Context manager for tempfile.mkdtemp()"""
|
"""Context manager for tempfile.mkdtemp()"""
|
||||||
@ -1223,3 +1225,32 @@ def test_data_initialization():
|
|||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
X, y = load_digits(return_X_y=True)
|
X, y = load_digits(return_X_y=True)
|
||||||
run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
|
run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
|
||||||
|
|
||||||
|
|
||||||
|
@parametrize_with_checks([xgb.XGBRegressor()])
|
||||||
|
def test_estimator_reg(estimator, check):
|
||||||
|
if os.environ["PYTEST_CURRENT_TEST"].find("check_supervised_y_no_nan") != -1:
|
||||||
|
# The test uses float64 and requires the error message to contain:
|
||||||
|
#
|
||||||
|
# "value too large for dtype(float64)",
|
||||||
|
#
|
||||||
|
# while XGBoost stores values as float32. But XGBoost does verify the label
|
||||||
|
# internally, so we replace this test with custom check.
|
||||||
|
rng = np.random.RandomState(888)
|
||||||
|
X = rng.randn(10, 5)
|
||||||
|
y = np.full(10, np.inf)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="contains NaN, infinity or a value too large"
|
||||||
|
):
|
||||||
|
estimator.fit(X, y)
|
||||||
|
return
|
||||||
|
if os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params") != -1:
|
||||||
|
# A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor
|
||||||
|
# returns actual internal default values for parameters in `get_params`, but those
|
||||||
|
# are set as `None` in sklearn interface to avoid duplication. So we fit a dummy
|
||||||
|
# model and obtain the default parameters here for the mutation tests.
|
||||||
|
from sklearn.datasets import make_regression
|
||||||
|
X, y = make_regression(n_samples=2, n_features=1)
|
||||||
|
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
|
||||||
|
|
||||||
|
check(estimator)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user