diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index ad99ed17c..14558b036 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1068,7 +1068,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m return ret.value def num_nonmissing(self) -> int: - """Get the number of non-missing values in the DMatrix.""" + """Get the number of non-missing values in the DMatrix. + + .. versionadded:: 1.7.0 + + """ ret = c_bst_ulong() _check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret))) return ret.value diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index f126af52b..6b193af7e 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -34,7 +34,8 @@ from .core import ( ) DispatchedDataBackendReturnType = Tuple[ - ctypes.c_void_p, Optional[FeatureNames], Optional[FeatureTypes]] + ctypes.c_void_p, Optional[FeatureNames], Optional[FeatureTypes] +] CAT_T = "c" @@ -217,27 +218,36 @@ def _is_modin_df(data: DataType) -> bool: _pandas_dtype_mapper = { - 'int8': 'int', - 'int16': 'int', - 'int32': 'int', - 'int64': 'int', - 'uint8': 'int', - 'uint16': 'int', - 'uint32': 'int', - 'uint64': 'int', - 'float16': 'float', - 'float32': 'float', - 'float64': 'float', - 'bool': 'i', - # nullable types + "int8": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "int", + "uint16": "int", + "uint32": "int", + "uint64": "int", + "float16": "float", + "float32": "float", + "float64": "float", + "bool": "i", +} + +# nullable types +pandas_nullable_mapper = { + "Int8": "int", "Int16": "int", "Int32": "int", "Int64": "int", + "UInt8": "i", + "UInt16": "i", + "UInt32": "i", + "UInt64": "i", "Float32": "float", "Float64": "float", "boolean": "i", } +_pandas_dtype_mapper.update(pandas_nullable_mapper) _ENABLE_CAT_ERR = ( "When categorical type is supplied, The experimental DMatrix parameter" @@ -304,27 +314,27 @@ def is_nullable_dtype(dtype: PandasDType) -> bool: is_integer_dtype, ) - # dtype: pd.core.arrays.numeric.NumericDtype - nullable_alias = {"Int16", "Int32", "Int64", "Float32", "Float64", "category"} - is_int = is_integer_dtype(dtype) and dtype.name in nullable_alias + is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper # np.bool has alias `bool`, while pd.BooleanDtype has `bzoolean`. is_bool = is_bool_dtype(dtype) and dtype.name == "boolean" - is_float = is_float_dtype(dtype) and dtype.name in nullable_alias + is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper return is_int or is_bool or is_float or is_categorical_dtype(dtype) -def _pandas_cat_null(data: DataFrame) -> DataFrame: +def pandas_cat_null(data: DataFrame) -> DataFrame: + """Handle categorical dtype and nullable extension types from pandas.""" from pandas.api.types import is_categorical_dtype # handle category codes and nullable. - cat_columns = [ - col - for col, dtype in zip(data.columns, data.dtypes) - if is_categorical_dtype(dtype) - ] - nul_columns = [ - col for col, dtype in zip(data.columns, data.dtypes) if is_nullable_dtype(dtype) - ] + cat_columns = [] + nul_columns = [] + for col, dtype in zip(data.columns, data.dtypes): + if is_categorical_dtype(dtype): + cat_columns.append(col) + # avoid an unnecessary conversion if possible + elif is_nullable_dtype(dtype): + nul_columns.append(col) + if cat_columns or nul_columns: # Avoid transformation due to: PerformanceWarning: DataFrame is highly # fragmented @@ -333,7 +343,7 @@ def _pandas_cat_null(data: DataFrame) -> DataFrame: transformed = data if cat_columns: - # DF doesn't have the cat attribute, so we use apply here + # DF doesn't have the cat attribute, as a result, we use apply here transformed[cat_columns] = ( transformed[cat_columns] .apply(lambda x: x.cat.codes) @@ -343,6 +353,10 @@ def _pandas_cat_null(data: DataFrame) -> DataFrame: if nul_columns: transformed[nul_columns] = transformed[nul_columns].astype(np.float32) + # TODO(jiamingy): Investigate the possibility of using dataframe protocol or arrow + # IPC format for pandas so that we can apply the data transformation inside XGBoost + # for better memory efficiency. + return transformed @@ -357,9 +371,8 @@ def _transform_pandas_df( from pandas.api.types import is_categorical_dtype, is_sparse if not all( - dtype.name in _pandas_dtype_mapper + (dtype.name in _pandas_dtype_mapper) or is_sparse(dtype) - or (is_nullable_dtype(dtype) and not is_categorical_dtype(dtype)) or (is_categorical_dtype(dtype) and enable_categorical) for dtype in data.dtypes ): @@ -369,7 +382,7 @@ def _transform_pandas_df( data, meta, feature_names, feature_types, enable_categorical ) - transformed = _pandas_cat_null(data) + transformed = pandas_cat_null(data) if meta and len(data.columns) > 1 and meta not in _matrix_meta: raise ValueError(f"DataFrame for {meta} cannot have multiple columns") @@ -404,14 +417,12 @@ def _is_pandas_series(data: DataType) -> bool: def _meta_from_pandas_series( - data: DataType, - name: str, - dtype: Optional[NumpyDType], - handle: ctypes.c_void_p + data: DataType, name: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p ) -> None: """Help transform pandas series for meta data like labels""" - data = data.values.astype('float') + data = data.values.astype("float") from pandas.api.types import is_sparse + if is_sparse(data): data = data.to_dense() # type: ignore assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 7c64f499a..4eeaad6de 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -773,6 +773,19 @@ def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool: return all((y - x) < tolerance for x, y in zip(L, L[1:])) +def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool: + """Assert whether two DMatrices contain the same predictors.""" + lcsr = lhs.get_data() + rcsr = rhs.get_data() + return all( + ( + np.array_equal(lcsr.data, rcsr.data), + np.array_equal(lcsr.indices, rcsr.indices), + np.array_equal(lcsr.indptr, rcsr.indptr), + ) + ) + + def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.float64]: """Evaluation metric for xgb.train""" label = dtrain.get_label() diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 5dc032074..7d63097dc 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -1,5 +1,5 @@ """Utilities for data generation.""" -from typing import Generator, Tuple +from typing import Any, Generator, Tuple, Union import numpy as np @@ -7,7 +7,7 @@ import numpy as np def np_dtypes( n_samples: int, n_features: int ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]: - """Generate all supported dtypes from numpy.""" + """Enumerate all supported dtypes from numpy.""" import pandas as pd rng = np.random.RandomState(1994) @@ -60,3 +60,61 @@ def np_dtypes( df_orig = pd.DataFrame(orig) df = pd.DataFrame(X) yield df_orig, df + + +def pd_dtypes() -> Generator: + """Enumerate all supported pandas extension types.""" + import pandas as pd + + # Integer + dtypes = [ + pd.UInt8Dtype(), + pd.UInt16Dtype(), + pd.UInt32Dtype(), + pd.UInt64Dtype(), + pd.Int8Dtype(), + pd.Int16Dtype(), + pd.Int32Dtype(), + pd.Int64Dtype(), + ] + + Null: Union[float, None, Any] = np.nan + orig = pd.DataFrame( + {"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=np.float32 + ) + for Null in (np.nan, None, pd.NA): + for dtype in dtypes: + df = pd.DataFrame( + {"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=dtype + ) + yield orig, df + + # Float + Null = np.nan + dtypes = [pd.Float32Dtype(), pd.Float64Dtype()] + orig = pd.DataFrame( + {"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=np.float32 + ) + for Null in (np.nan, None, pd.NA): + for dtype in dtypes: + df = pd.DataFrame( + {"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype + ) + yield orig, df + + # Categorical + orig = orig.astype("category") + for Null in (np.nan, None, pd.NA): + df = pd.DataFrame( + {"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, + dtype=pd.CategoricalDtype(), + ) + yield orig, df + + # Boolean + for Null in [None, pd.NA]: + data = {"f0": [True, False, Null, True], "f1": [False, True, Null, True]} + # pd.NA is not convertible to bool. + orig = pd.DataFrame(data, dtype=np.bool_ if Null is None else pd.BooleanDtype()) + df = pd.DataFrame(data, dtype=pd.BooleanDtype()) + yield orig, df diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index f192f813e..3fcc62967 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -461,8 +461,4 @@ class TestDMatrix: for orig, x in np_dtypes(n_samples, n_features): m0 = xgb.DMatrix(orig) m1 = xgb.DMatrix(x) - csr0 = m0.get_data() - csr1 = m1.get_data() - np.testing.assert_allclose(csr0.data, csr1.data) - np.testing.assert_allclose(csr0.indptr, csr1.indptr) - np.testing.assert_allclose(csr0.indices, csr1.indices) + assert tm.predictor_equal(m0, m1) diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index e62137e36..476ad775e 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -10,6 +10,7 @@ from xgboost.testing import ( make_batches_sparse, make_categorical, make_sparse_regression, + predictor_equal, ) from xgboost.testing.data import np_dtypes @@ -246,11 +247,7 @@ class TestQuantileDMatrix: for orig, x in np_dtypes(n_samples, n_features): m0 = xgb.QuantileDMatrix(orig) m1 = xgb.QuantileDMatrix(x) - csr0 = m0.get_data() - csr1 = m1.get_data() - np.testing.assert_allclose(csr0.data, csr1.data) - np.testing.assert_allclose(csr0.indptr, csr1.indptr) - np.testing.assert_allclose(csr0.indices, csr1.indices) + assert predictor_equal(m0, m1) # unsupported types for dtype in [ diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 209e5cf6f..863569691 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -4,6 +4,7 @@ import tempfile import numpy as np import pytest from test_dmatrix import set_base_margin_info +from xgboost.testing.data import pd_dtypes import xgboost as xgb from xgboost import testing as tm @@ -297,70 +298,22 @@ class TestPandas: assert 'auc' not in cv.columns[0] assert 'error' in cv.columns[0] - def test_nullable_type(self): - y = np.random.default_rng(0).random(4) + def test_nullable_type(self) -> None: + from pandas.api.types import is_categorical - def to_bytes(Xy: xgb.DMatrix) -> bytes: - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "Xy.dmatrix") - Xy.save_binary(path) - with open(path, "rb") as fd: - result = fd.read() - return result + for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix): + for orig, df in pd_dtypes(): + enable_categorical = any(is_categorical for dtype in df.dtypes) - def test_int(dtype) -> bytes: - arr = pd.DataFrame( - {"f0": [1, 2, None, 3], "f1": [4, 3, None, 1]}, dtype=dtype - ) - Xy = xgb.DMatrix(arr, y) - Xy.feature_types = None - return to_bytes(Xy) + m_orig = DMatrixT(orig, enable_categorical=enable_categorical) + # extension types + m_etype = DMatrixT(df, enable_categorical=enable_categorical) + # different from pd.BooleanDtype(), None is converted to False with bool + if any(dtype == "bool" for dtype in orig.dtypes): + assert not tm.predictor_equal(m_orig, m_etype) + else: + assert tm.predictor_equal(m_orig, m_etype) - b0 = test_int(np.float32) - b1 = test_int(pd.Int16Dtype()) - assert b0 == b1 - - def test_bool(dtype) -> bytes: - arr = pd.DataFrame( - {"f0": [True, False, None, True], "f1": [False, True, None, True]}, - dtype=dtype, - ) - Xy = xgb.DMatrix(arr, y) - Xy.feature_types = None - return to_bytes(Xy) - - b0 = test_bool(pd.BooleanDtype()) - b1 = test_bool(bool) - assert b0 != b1 # None is converted to False with np.bool - - data = {"f0": [1.0, 2.0, None, 3.0], "f1": [3.0, 2.0, None, 1.0]} - - arr = np.array([data["f0"], data["f1"]]).T - Xy = xgb.DMatrix(arr, y) - Xy.feature_types = None - Xy.feature_names = None - from_np = to_bytes(Xy) - - def test_float(dtype) -> bytes: - arr = pd.DataFrame(data, dtype=dtype) - Xy = xgb.DMatrix(arr, y) - Xy.feature_types = None - Xy.feature_names = None - return to_bytes(Xy) - - b0 = test_float(pd.Float64Dtype()) - b1 = test_float(float) - assert b0 == b1 # both are converted to NaN - assert b0 == from_np - - def test_cat(dtype) -> bytes: - arr = pd.DataFrame(data, dtype=dtype) - if dtype is None: - arr = arr.astype("category") - Xy = xgb.DMatrix(arr, y, enable_categorical=True) - Xy.feature_types = None - return to_bytes(Xy) - - b0 = test_cat(pd.CategoricalDtype()) - b1 = test_cat(None) - assert b0 == b1 + f0 = df["f0"] + with pytest.raises(ValueError, match="Label contains NaN"): + xgb.DMatrix(df, f0, enable_categorical=enable_categorical)