diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 4b5384388..e25a15e43 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -584,8 +584,6 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             `gpu_predictor` and pandas input are required.

         """
-        if isinstance(data, list):
-            raise TypeError("Input data can not be a list.")
         if group is not None and qid is not None:
             raise ValueError("Either one of `group` or `qid` should be None.")
@@ -2005,6 +2003,10 @@ class Booster(object):
         p_handle = ctypes.c_void_p()
         assert proxy is None or isinstance(proxy, _ProxyDMatrix)
         if validate_features:
+            if not hasattr(data, "shape"):
+                raise TypeError(
+                    "`shape` attribute is required when `validate_features` is True."
+                )
             if len(data.shape) != 1 and self.num_features() != data.shape[1]:
                 raise ValueError(
                     f"Feature shape mismatch, expected: {self.num_features()}, "
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 501a2bee7..7df15dd82 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -32,6 +32,11 @@ def _check_complex(data):
         raise ValueError('Complex data not supported')


+def _check_data_shape(data: Any) -> None:
+    if hasattr(data, "shape") and len(data.shape) != 2:
+        raise ValueError("Please reshape the input data into 2-dimensional matrix.")
+
+
 def _is_scipy_csr(data):
     try:
         import scipy
@@ -524,16 +529,18 @@ def _is_list(data):
     return isinstance(data, list)


-def _from_list(data, missing, feature_names, feature_types):
-    raise TypeError('List input data is not supported for data')
+def _from_list(data, missing, n_threads, feature_names, feature_types):
+    array = np.array(data)
+    _check_data_shape(data)
+    return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)


 def _is_tuple(data):
     return isinstance(data, tuple)


-def _from_tuple(data, missing, feature_names, feature_types):
-    return _from_list(data, missing, feature_names, feature_types)
+def _from_tuple(data, missing, n_threads, feature_names, feature_types):
+    return _from_list(data, missing, n_threads, feature_names, feature_types)


 def _is_iter(data):
@@ -566,6 +573,8 @@ def dispatch_data_backend(data, missing, threads,
                           feature_names, feature_types,
                           enable_categorical=False):
     '''Dispatch data for DMatrix.'''
+    if not _is_cudf_ser(data) and not _is_pandas_series(data):
+        _check_data_shape(data)
     if _is_scipy_csr(data):
         return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
     if _is_scipy_csc(data):
@@ -578,9 +587,9 @@ def dispatch_data_backend(data, missing, threads,
     if _is_uri(data):
         return _from_uri(data, missing, feature_names, feature_types)
     if _is_list(data):
-        return _from_list(data, missing, feature_names, feature_types)
+        return _from_list(data, missing, threads, feature_names, feature_types)
     if _is_tuple(data):
-        return _from_tuple(data, missing, feature_names, feature_types)
+        return _from_tuple(data, missing, threads, feature_names, feature_types)
     if _is_pandas_df(data):
         return _from_pandas_df(data, enable_categorical, missing, threads,
                                feature_names, feature_types)
@@ -612,11 +621,12 @@ def dispatch_data_backend(data, missing, threads,
         return _from_pandas_series(data, missing, threads, feature_names,
                                    feature_types)
     if _has_array_protocol(data):
-        pass
+        array = np.asarray(data)
+        return _from_numpy_array(array, missing, threads, feature_names,
+                                 feature_types)

     converted = _convert_unknown_data(data)
-    if converted:
-        return _from_scipy_csr(data, missing, threads,
-                               feature_names, feature_types)
+    if converted is not None:
+        return _from_scipy_csr(converted, missing, threads, feature_names,
+                               feature_types)

     raise TypeError('Not supported type for data.' + str(type(data)))
@@ -630,11 +640,12 @@ def _to_data_type(dtype: str, name: str):
     return dtype_map[dtype]


-def _validate_meta_shape(data):
-    if hasattr(data, 'shape'):
-        assert len(data.shape) == 1 or (
-            len(data.shape) == 2 and
-            (data.shape[1] == 0 or data.shape[1] == 1))
+def _validate_meta_shape(data, name: str) -> None:
+    if hasattr(data, "shape"):
+        if len(data.shape) > 2 or (
+            len(data.shape) == 2 and (data.shape[1] != 0 and data.shape[1] != 1)
+        ):
+            raise ValueError(f"Invalid shape: {data.shape} for {name}")


 def _meta_from_numpy(data, field, dtype, handle):
@@ -702,7 +713,7 @@ def _meta_from_dt(data, field, dtype, handle):
 def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
     '''Dispatch for meta info.'''
     handle = matrix.handle
-    _validate_meta_shape(data)
+    _validate_meta_shape(data, name)
     if data is None:
         return
     if _is_list(data):
@@ -751,7 +762,9 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
         _meta_from_numpy(data, name, dtype, handle)
         return
     if _has_array_protocol(data):
-        pass
+        array = np.asarray(data)
+        _meta_from_numpy(array, name, dtype, handle)
+        return
     raise TypeError('Unsupported type for ' + name, str(type(data)))
@@ -802,6 +815,8 @@ def _proxy_transform(data, feature_names, feature_types, enable_categorical):

 def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -> None:
     """Dispatch for DeviceQuantileDMatrix."""
+    if not _is_cudf_ser(data) and not _is_pandas_series(data):
+        _check_data_shape(data)
     if _is_cudf_df(data):
         proxy._set_data_from_cuda_columnar(data)  # pylint: disable=W0212
         return
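Note: a minimal sketch of the behavior the `data.py` changes above are meant to enable, assuming a build that includes this patch; the error messages come from the new `_check_data_shape` and `_validate_meta_shape`:

```python
import numpy as np
import xgboost as xgb

# Nested lists (and tuples) are now converted through np.array and accepted.
dtrain = xgb.DMatrix([[1.0, 2.0], [3.0, 4.0]], label=[0.0, 1.0])
assert dtrain.num_row() == 2 and dtrain.num_col() == 2

# 1-D inputs that carry a `shape` attribute are rejected up front at dispatch.
try:
    xgb.DMatrix(np.arange(4.0))
except ValueError as err:
    print(err)  # Please reshape the input data into 2-dimensional matrix.

# Meta info with more than one column now raises a named error instead of
# tripping a bare assert.
try:
    xgb.DMatrix(np.random.rand(4, 2), label=np.random.rand(4, 2))
except ValueError as err:
    print(err)  # Invalid shape: (4, 2) for label
```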
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 257cba0f9..cbce3dc05 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -419,7 +419,6 @@ class XGBModel(XGBModelBase):
         self.base_score = base_score
         self.missing = missing
         self.num_parallel_tree = num_parallel_tree
-        self.kwargs = kwargs
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.monotone_constraints = monotone_constraints
@@ -429,6 +428,8 @@ class XGBModel(XGBModelBase):
         self.validate_parameters = validate_parameters
         self.predictor = predictor
         self.enable_categorical = enable_categorical
+        if kwargs:
+            self.kwargs = kwargs

     def _more_tags(self) -> Dict[str, bool]:
         '''Tags used for scikit-learn data validation.'''
@@ -469,6 +470,8 @@ class XGBModel(XGBModelBase):
             if hasattr(self, key):
                 setattr(self, key, value)
             else:
+                if not hasattr(self, "kwargs"):
+                    self.kwargs = {}
                 self.kwargs[key] = value

         if hasattr(self, '_Booster'):
@@ -491,7 +494,7 @@ class XGBModel(XGBModelBase):
             cp.__class__ = cp.__class__.__bases__[0]
             params.update(cp.__class__.get_params(cp, deep))
         # if kwargs is a dict, update params accordingly
-        if isinstance(self.kwargs, dict):
+        if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
             params.update(self.kwargs)
         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(
@@ -745,7 +748,6 @@ class XGBModel(XGBModelBase):
         """
         evals_result: TrainingCallback.EvalsLog = {}
-
         train_dmatrix, evals = _wrap_evaluation_matrices(
             missing=self.missing,
             X=X,
@@ -1169,7 +1171,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             ):
                 raise ValueError(label_encoding_check_error)
         else:
-            self.classes_ = np.unique(y)
+            self.classes_ = np.unique(np.asarray(y))
             self.n_classes_ = len(self.classes_)
             if not self.use_label_encoder and (
                 not np.array_equal(self.classes_, np.arange(self.n_classes_))
@@ -1206,11 +1208,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             label_transform = lambda x: x

         model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
-        if len(X.shape) != 2:
-            # Simply raise an error here since there might be many
-            # different ways of reshaping
-            raise ValueError("Please reshape the input data X into 2-dimensional matrix.")
-
         train_dmatrix, evals = _wrap_evaluation_matrices(
             missing=self.missing,
             X=X,
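Note: a sketch of the lazily-created `kwargs` contract from the `sklearn.py` changes above, assuming this patch is applied; `eta` is only an example of a key that is not a named constructor argument:

```python
import xgboost as xgb

# Without extra keyword arguments, no `kwargs` attribute is created at all,
# which keeps scikit-learn's clone()/get_params() round-trips clean.
reg = xgb.XGBRegressor()
assert not hasattr(reg, "kwargs")
reg.get_params()  # get_params now guards with hasattr, so this still works

# Unknown keys passed to set_params fall back into a lazily-created dict.
reg.set_params(eta=0.1)
assert reg.kwargs == {"eta": 0.1}
```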
diff --git a/src/data/data.cc b/src/data/data.cc
index de27f51f6..42a144c79 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -360,13 +360,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
     labels.resize(num);
     DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
                        std::copy(cast_dptr, cast_dptr + num, labels.begin()));
+    auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) {
+      return std::isnan(y) || std::isinf(y);
+    });
+    CHECK(valid) << "Label contains NaN, infinity or a value too large.";
   } else if (!std::strcmp(key, "weight")) {
     auto& weights = weights_.HostVector();
     weights.resize(num);
     DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
                        std::copy(cast_dptr, cast_dptr + num, weights.begin()));
-    auto valid = std::all_of(weights.cbegin(), weights.cend(),
-                             [](float w) { return w >= 0; });
+    auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) {
+      return w < 0 || std::isinf(w) || std::isnan(w);
+    });
     CHECK(valid) << "Weights must be positive values.";
   } else if (!std::strcmp(key, "base_margin")) {
     auto& base_margin = base_margin_.HostVector();
@@ -419,8 +424,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
         dtype, dptr, cast_dptr,
         std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
     bool valid =
-        std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
-                    [](float w) { return w >= 0; });
+        std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
+                     [](float w) { return w < 0; });
     CHECK(valid) << "Feature weight must be greater than 0.";
   } else {
     LOG(FATAL) << "Unknown key for MetaInfo: " << key;
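Note: the CPU-side `CHECK` above surfaces in Python as an `XGBoostError`, which subclasses `ValueError`, so the new label validation can be exercised like this sketch (assuming this patch):

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(4, 2)
# NaN or +/-inf in the label now aborts DMatrix construction.
try:
    xgb.DMatrix(X, label=np.array([0.0, 1.0, np.inf, 0.0]))
except ValueError as err:
    assert "NaN, infinity or a value too large" in str(err)
```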
diff --git a/src/data/data.cu b/src/data/data.cu
index de8a8c248..2c421938c 100644
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -114,10 +114,11 @@ void CopyQidImpl(ArrayInterface array_interface,

 namespace {
 // thrust::all_of tries to copy lambda function.
-struct AllOfOp {
-  __device__ bool operator()(float w) {
-    return w >= 0;
-  }
+struct LabelsCheck {
+  __device__ bool operator()(float y) { return ::isnan(y) || ::isinf(y); }
+};
+struct WeightsCheck {
+  __device__ bool operator()(float w) { return LabelsCheck{}(w) || w < 0; }  // NOLINT
 };
 }  // anonymous namespace

@@ -142,11 +143,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {

   if (key == "label") {
     CopyInfoImpl(array_interface, &labels_);
+    auto ptr = labels_.ConstDevicePointer();
+    auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(),
+                                 LabelsCheck{});
+    CHECK(valid) << "Label contains NaN, infinity or a value too large.";
   } else if (key == "weight") {
     CopyInfoImpl(array_interface, &weights_);
     auto ptr = weights_.ConstDevicePointer();
-    auto valid =
-        thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{});
+    auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(),
+                                 WeightsCheck{});
     CHECK(valid) << "Weights must be positive values.";
   } else if (key == "base_margin") {
     CopyInfoImpl(array_interface, &base_margin_);
@@ -165,9 +170,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
   } else if (key == "feature_weights") {
     CopyInfoImpl(array_interface, &feature_weigths);
     auto d_feature_weights = feature_weigths.ConstDeviceSpan();
-    auto valid = thrust::all_of(
+    auto valid = thrust::none_of(
         thrust::device, d_feature_weights.data(),
-        d_feature_weights.data() + d_feature_weights.size(), AllOfOp{});
+        d_feature_weights.data() + d_feature_weights.size(), WeightsCheck{});
     CHECK(valid) << "Feature weight must be greater than 0.";
     return;
   } else {
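Note: `WeightsCheck` reuses `LabelsCheck`, so the GPU path now rejects NaN and infinity in addition to negative weights, matching the CPU predicate in `data.cc`. A sketch of the user-visible effect (assuming this patch):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(4, 2), np.random.rand(4)
# A negative or non-finite weight fails the new validity check.
try:
    xgb.DMatrix(X, label=y, weight=np.array([1.0, -1.0, 1.0, np.nan]))
except ValueError as err:
    assert "Weights must be positive" in str(err)
```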
diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py
index 1d201ece9..c3e6a0dad 100644
--- a/tests/python/test_dmatrix.py
+++ b/tests/python/test_dmatrix.py
@@ -330,3 +330,12 @@ class TestDMatrix:
         with pytest.warns(UserWarning):
             d = Data()
             xgb.DMatrix(d)
+
+        from scipy import sparse
+        rng = np.random.RandomState(1994)
+        X = rng.rand(10, 10)
+        y = rng.rand(10)
+        X = sparse.dok_matrix(X)
+        Xy = xgb.DMatrix(X, y)
+        assert Xy.num_row() == 10
+        assert Xy.num_col() == 10
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index d44d0e3af..cf31929cc 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -13,6 +13,8 @@ rng = np.random.RandomState(1994)

 pytestmark = pytest.mark.skipif(**tm.no_sklearn())

+from sklearn.utils.estimator_checks import parametrize_with_checks
+

 class TemporaryDirectory(object):
     """Context manager for tempfile.mkdtemp()"""
@@ -1223,3 +1225,32 @@ def test_data_initialization():
     from sklearn.datasets import load_digits
     X, y = load_digits(return_X_y=True)
     run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
+
+
+@parametrize_with_checks([xgb.XGBRegressor()])
+def test_estimator_reg(estimator, check):
+    if os.environ["PYTEST_CURRENT_TEST"].find("check_supervised_y_no_nan") != -1:
+        # The test uses float64 and requires the error message to contain:
+        #
+        #   "value too large for dtype(float64)",
+        #
+        # while XGBoost stores values as float32.  But XGBoost does verify the
+        # label internally, so we replace this test with a custom check.
+        rng = np.random.RandomState(888)
+        X = rng.randn(10, 5)
+        y = np.full(10, np.inf)
+        with pytest.raises(
+            ValueError, match="contains NaN, infinity or a value too large"
+        ):
+            estimator.fit(X, y)
+        return
+    if os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params") != -1:
+        # A hack to pass the scikit-learn parameter mutation tests.  The XGBoost
+        # regressor returns the actual internal default values for parameters in
+        # `get_params`, but those are set as `None` in the sklearn interface to
+        # avoid duplication.  So we fit a dummy model and obtain the default
+        # parameters here for the mutation tests.
+        from sklearn.datasets import make_regression
+        X, y = make_regression(n_samples=2, n_features=1)
+        estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
+
+    check(estimator)
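Note: for context on the `PYTEST_CURRENT_TEST` lookups in `test_estimator_reg`, pytest exports the id of the currently running test in that environment variable (e.g. `tests/python/test_with_sklearn.py::test_estimator_reg[...] (call)`), which is what lets the test special-case individual scikit-learn checks. A hypothetical helper (`current_check_is` is not part of the patch) illustrating the pattern:

```python
import os

def current_check_is(name: str) -> bool:
    # PYTEST_CURRENT_TEST holds "<file>::<test id> (<phase>)" while a test runs.
    return os.environ.get("PYTEST_CURRENT_TEST", "").find(name) != -1
```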